|
from datetime import datetime |
|
from miditok import Event, MIDILike |
|
import os |
|
import json |
|
from time import perf_counter |
|
from joblib import Parallel, delayed |
|
from zipfile import ZipFile, ZIP_DEFLATED |
|
from scipy.io.wavfile import write |
|
import numpy as np |
|
from pydub import AudioSegment |
|
import shutil |
|
|
|
|
|
def writeToFile(path, content):
    """Write ``content`` to ``path``, creating missing parent directories.

    A dict (or dict subclass) is serialized as JSON. Any other value is
    written as text and is converted with ``str()`` first when it is not
    already a string.

    Args:
        path: Destination file path; an existing file is overwritten.
        content: The dict, string, or other object to write.
    """
    # A bare filename has an empty dirname and os.makedirs("") raises,
    # so only create the parent directory when there is one.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    if isinstance(content, dict):
        with open(path, "w") as json_file:
            json.dump(content, json_file)
        return

    if not isinstance(content, str):
        content = str(content)
    with open(path, "w") as f:
        f.write(content)
|
|
|
|
|
|
|
def readFromFile(path, isJSON=False):
    """Read the file at ``path``.

    Args:
        path: File to open in text mode.
        isJSON: When true, parse the file as JSON and return the decoded
            object; otherwise return the raw text.
    """
    with open(path, "r") as handle:
        return json.load(handle) if isJSON else handle.read()
|
|
|
|
|
def chain(input, funcs, *params):
    """Thread ``input`` through ``funcs`` in order and return the final result.

    Each function is called as ``func(result, *params)`` when its signature
    accepts the extra parameters, and as ``func(result)`` otherwise.

    The decision is made by inspecting the signature rather than by catching
    ``TypeError`` from the call itself, so a ``TypeError`` raised *inside* a
    function propagates to the caller instead of silently triggering a second
    call without the parameters.

    Args:
        input: Initial value handed to the first function.
        funcs: Iterable of callables applied one after another.
        *params: Extra positional arguments offered to every function.

    Returns:
        The value returned by the last function, or ``input`` unchanged when
        ``funcs`` is empty.
    """
    import inspect  # local import keeps this function self-contained

    res = input
    for func in funcs:
        if not params:
            # Nothing extra to offer: a single plain call, never a retry.
            res = func(res)
            continue

        try:
            signature = inspect.signature(func)
        except (TypeError, ValueError):
            signature = None

        if signature is None:
            # Some callables (e.g. certain builtins) expose no signature;
            # for those, probe by calling, as before.
            try:
                res = func(res, *params)
            except TypeError:
                res = func(res)
            continue

        try:
            signature.bind(res, *params)
        except TypeError:
            takes_params = False
        else:
            takes_params = True
        res = func(res, *params) if takes_params else func(res)
    return res
|
|
|
|
|
def to_beat_str(value, beat_res=8):
    """Format a beat count as the dotted string "beats.ticks.resolution".

    ``value`` is first truncated to a whole number of ticks, where one beat
    holds ``beat_res`` ticks; the result lists the whole beats, the leftover
    ticks and the resolution itself, e.g. ``1.5`` -> ``"1.4.8"``.
    """
    ticks = int(value * beat_res)
    whole_beats = int(ticks / beat_res)
    leftover_ticks = int(ticks % beat_res)
    return f"{whole_beats}.{leftover_ticks}.{beat_res}"
|
|
|
|
|
def to_base10(beat_str):
    """Convert a dotted "beats.ticks.resolution" string to a number of beats.

    The three integer fields are combined as ``beats + ticks / resolution``.
    """
    whole_beats, ticks, resolution = split_dots(beat_str)
    fraction = ticks / resolution
    return whole_beats + fraction
|
|
|
|
|
def split_dots(value):
    """Split a dot-separated string and return its fields as a list of ints."""
    return [int(field) for field in value.split(".")]
|
|
|
|
|
def compute_list_average(l):
    """Return the arithmetic mean of the numbers in ``l``.

    Raises ``ZeroDivisionError`` when ``l`` is empty.
    """
    total, count = sum(l), len(l)
    return total / count
|
|
|
|
|
def get_datetime():
    """Return the current local time formatted as ``YYYYMMDD_HHMMSS``."""
    return f"{datetime.now():%Y%m%d_%H%M%S}"
|
|
|
|
|
def get_text(event):
    """Render an event as its text token, including the trailing space.

    Structural events (piece/track/bar markers) map to a fixed token.
    Instrument, time-shift and note events map to ``NAME=<event.value>``.
    Any other event type yields an empty string.
    """
    kind = event.type

    # Event types whose token carries no value.
    fixed_tokens = {
        "Piece-Start": "PIECE_START ",
        "Track-Start": "TRACK_START ",
        "Track-End": "TRACK_END ",
        "Bar-Start": "BAR_START ",
        "Bar-End": "BAR_END ",
    }
    if kind in fixed_tokens:
        return fixed_tokens[kind]

    # Event types rendered as "<PREFIX>=<value> ".
    valued_prefixes = {
        "Instrument": "INST",
        "Time-Shift": "TIME_SHIFT",
        "Note-On": "NOTE_ON",
        "Note-Off": "NOTE_OFF",
    }
    if kind in valued_prefixes:
        return f"{valued_prefixes[kind]}={event.value} "

    return ""
|
|
|
|
|
def get_event(text, value=None):
    """Build the ``Event`` that corresponds to a text token name.

    Args:
        text: Token name without its value, e.g. ``"NOTE_ON"``.
        value: Value attached to the token, passed through to the event.

    Returns:
        An ``Event``, or ``None`` for ``TRACK_START``, ``TRACK_END`` and any
        unrecognized token name.
    """
    if text == "TIME_DELTA":
        # TIME_DELTA is converted into a Time-Shift whose value is the
        # integer delta divided by 4, rendered as a beat string.
        # NOTE(review): presumably the delta counts quarter-beat steps — confirm.
        return Event("Time-Shift", to_beat_str(int(value) / 4))

    # Token names that translate one-to-one into an event type.
    event_types = {
        "PIECE_START": "Piece-Start",
        "INST": "Instrument",
        "BAR_START": "Bar-Start",
        "BAR_END": "Bar-End",
        "TIME_SHIFT": "Time-Shift",
        "NOTE_ON": "Note-On",
        "NOTE_OFF": "Note-Off",
    }
    event_type = event_types.get(text)
    if event_type is None:
        return None
    return Event(event_type, value)
|
|
|
|
|
|
|
def get_miditok():
    """Return a ``MIDILike`` tokenizer built with this module's fixed settings.

    The tokenizer receives a pitch range of ``range(0, 140)`` and a beat
    resolution mapping of ``{(0, 400): 8}``.
    NOTE(review): presumably 8 samples per beat for beats 0-400 — confirm
    against the installed miditok version.
    """
    return MIDILike(range(0, 140), {(0, 400): 8})
|
|
|
|
|
class WriteTextMidiToFile:
    """Save a generated piece and its generation data to a timestamped JSON file.

    Args:
        generate_midi: Object exposing ``generated_piece`` and
            ``piece_by_track`` attributes.
        output_path: Directory in which the JSON file is written.
    """

    def __init__(self, generate_midi, output_path):
        self.generated_midi = generate_midi.generated_piece
        self.output_path = output_path
        self.hyperparameter_and_bars = generate_midi.piece_by_track

    def hashing_seq(self):
        """Record the current timestamp and derive the output file path from it.

        Despite the name, nothing is hashed: the file is simply named after
        the timestamp returned by ``get_datetime()``.
        """
        self.current_time = get_datetime()
        self.output_path_filename = f"{self.output_path}/{self.current_time}.json"

    def wrapping_seq_hyperparameters_in_dict(self):
        """Return the dict that gets serialized to the output file."""
        return {
            "generate_midi": self.generated_midi,
            "hyperparameters_and_bars": self.hyperparameter_and_bars,
        }

    def text_midi_to_file(self):
        """Write the piece to a new timestamped JSON file and return its path."""
        self.hashing_seq()
        output_dict = self.wrapping_seq_hyperparameters_in_dict()
        # Create the target directory up front so the write cannot fail on a
        # missing folder.
        os.makedirs(os.path.dirname(self.output_path_filename), exist_ok=True)
        writeToFile(self.output_path_filename, output_dict)
        # Report only after the write succeeded, so the message is never wrong.
        print(f"Token generate_midi written: {self.output_path_filename}")
        return self.output_path_filename
|
|
|
|
|
def get_files(directory, extension, recursive=False):
    """
    List the paths of all files in a directory that have a given extension.
    directory: the directory to search as a Path object
    extension: the file extension to match as a string, without the dot
    recursive: whether to also search all subdirectories
    """
    pattern = f"*.{extension}"
    # rglob walks the whole tree, glob only looks at the directory itself.
    search = directory.rglob if recursive else directory.glob
    return list(search(pattern))
|
|
|
|
|
def timeit(func):
    """Decorator that prints how long each call to ``func`` took.

    The wrapper forwards all arguments, returns the wrapped function's result
    unchanged, and preserves its name and docstring. Nothing is printed when
    the call raises.
    """
    from functools import wraps  # local import keeps this block self-contained

    @wraps(func)  # keep __name__/__doc__ of the decorated function intact
    def wrapper(*args, **kwargs):
        start = perf_counter()
        result = func(*args, **kwargs)
        end = perf_counter()
        print(f"{func.__name__} took {end - start:.2f} seconds to run.")
        return result

    return wrapper
|
|
|
|
|
class FileCompressor:
    """Unzip archives and zip text files of a directory in parallel with joblib.

    Args:
        input_directory: Path object of the directory searched for ``.zip``
            files by ``unzip``.
        output_directory: Path object of the directory that receives the
            extracted files; ``zip`` also reads its ``.txt`` files from here
            and writes the archives back into it.
        n_jobs: Number of parallel jobs handed to ``joblib.Parallel``.
    """

    def __init__(self, input_directory, output_directory, n_jobs=-1):
        self.input_directory = input_directory
        self.output_directory = output_directory
        self.n_jobs = n_jobs

    def unzip_file(self, file):
        """uncompress single zip file"""
        with ZipFile(file, "r") as zip_ref:
            zip_ref.extractall(self.output_directory)

    def zip_file(self, file):
        """compress a single text file to a new zip file and delete the original"""
        output_file = self.output_directory / (file.stem + ".zip")
        with ZipFile(output_file, "w") as zip_ref:
            zip_ref.write(file, arcname=file.name, compress_type=ZIP_DEFLATED)
        # Delete the source only after the archive has been finalized and
        # closed, so a failed write never removes the only copy of the data.
        file.unlink()

    @timeit
    def unzip(self):
        """uncompress all zip files in folder"""
        files = get_files(self.input_directory, extension="zip")
        Parallel(n_jobs=self.n_jobs)(delayed(self.unzip_file)(file) for file in files)

    @timeit
    def zip(self):
        """compress all text files in folder to new zip files and remove the text files"""
        # The text files are looked up in output_directory, not input_directory.
        files = get_files(self.output_directory, extension="txt")
        Parallel(n_jobs=self.n_jobs)(delayed(self.zip_file)(file) for file in files)
|
|
|
|
|
def load_jsonl(filepath):
    """Load a jsonl file

    Each non-blank line is parsed as one JSON document. Blank or
    whitespace-only lines (for instance an extra newline at the end of the
    file) are skipped instead of raising a decode error.

    Returns:
        A list with the decoded value of every JSON line, in file order.
    """
    with open(filepath, "r") as f:
        return [json.loads(line) for line in f if line.strip()]
|
|
|
|
|
def write_mp3(waveform, output_path, bitrate="92k", sample_rate=44100):
    """
    Write a waveform to an mp3 file.
    output_path: Path object for the output mp3 file
    waveform: numpy array of the waveform
    bitrate: bitrate string handed to the mp3 export (default "92k")
    sample_rate: sample rate in Hz used for the intermediate wav file
        (default 44100)

    The waveform is first written as a float32 wav file next to the output
    (same name, ".wav" suffix), converted to mp3, and the wav file is then
    removed. An existing wav file at that path is overwritten and deleted.
    """
    wav_path = output_path.with_suffix(".wav")
    try:
        write(wav_path, sample_rate, waveform.astype(np.float32))
        AudioSegment.from_wav(wav_path).export(
            output_path, format="mp3", bitrate=bitrate
        )
    finally:
        # Always clean up the intermediate file, even when writing or the
        # export fails; missing_ok covers a failure before it was created.
        wav_path.unlink(missing_ok=True)
|
|
|
|
|
def copy_file(input_file, output_dir):
    """Copy input_file into output_dir, keeping its file name.

    Both arguments are used as Path objects; nothing is returned.
    """
    shutil.copy(input_file, output_dir / input_file.name)
|
|
|
|
|
def index_has_substring(list, substring):
    """Return the index of the first item containing ``substring``, or -1.

    Items are checked in order with the ``in`` operator.
    """
    matches = (pos for pos, item in enumerate(list) if substring in item)
    return next(matches, -1)
|
|