# rule-guided-music/datasets/piano_roll_all.py
# Preprocess MIDI files into stacked piano-roll / onset / sustain-pedal
# .npy excerpts for rule-guided music generation datasets.
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import pretty_midi
import pandas as pd
import numpy as np
from tqdm import tqdm
import math
from music_rule_guidance.music_rules import MAX_PIANO, MIN_PIANO
import matplotlib.pyplot as plt
# Matplotlib defaults for the optional debug plots (show=True paths below).
plt.rcParams["figure.figsize"] = (6,3)
plt.rcParams['figure.dpi'] = 300
plt.rcParams['savefig.dpi'] = 300
# MIDI control-change number for the sustain pedal (CC 64).
CC_SUSTAIN_PEDAL = 64
def split_csv(csv_path='merged_midi.csv'):
    """Split a merged metadata CSV into per-split CSVs.

    Reads ``csv_path`` (must contain a ``split`` column) and writes
    ``<stem>/train.csv``, ``<stem>/validation.csv`` and ``<stem>/test.csv``,
    where ``<stem>`` is ``csv_path`` with its ``.csv`` extension removed.
    """
    df = pd.read_csv(csv_path)
    # NOTE: assumes csv_path actually ends in '.csv'; otherwise rfind() is -1
    # and the stem would be truncated unexpectedly.
    out_dir = csv_path[:csv_path.rfind('.csv')]
    # Bug fix: the output directory must exist before to_csv can write into it.
    os.makedirs(out_dir, exist_ok=True)
    for split in ('train', 'validation', 'test'):
        path = os.path.join(out_dir, split + '.csv')
        df[df.split == split].to_csv(path, index=False)
    return
def quantize_pedal(value, num_bins=8):
    """Quantize an integer value from 0 to 127 into 8 bins and return the center value of the bin."""
    if not 0 <= value <= 127:
        raise ValueError("Value should be between 0 and 127")
    # Width of each bin; 128 // 8 == 16 for the default binning.
    width = 128 // num_bins
    # Map the value to its bin, then to that bin's center.
    center = (value // width) * width + width // 2
    # Clamp so an oversized last-bin center never exceeds the MIDI range.
    return min(center, 127)
def get_full_piano_roll(midi_data, fs, show=False):
    """Stack note, onset and sustain-pedal rolls into one (3, 128, T) array.

    The pedal is NOT folded into note durations (pedal_threshold=None);
    instead its quantized state is written explicitly into a third channel
    across the MIN_PIANO..MAX_PIANO pitch rows.
    """
    # Custom pretty_midi fork: onset=True returns (piano_roll, onset_roll).
    piano_roll, onset_roll = midi_data.get_piano_roll(fs=fs, pedal_threshold=None, onset=True)
    pedal_roll = np.zeros_like(piano_roll)
    num_frames = pedal_roll.shape[-1]
    for instrument in midi_data.instruments:
        sustain_events = (cc for cc in instrument.control_changes
                          if cc.number == CC_SUSTAIN_PEDAL)
        for event in sustain_events:
            frame = int(event.time * fs)
            if frame >= num_frames:
                continue
            current = pedal_roll[MIN_PIANO, frame]
            # quantize_pedal maps a CC value of 0 to 8, which keeps genuine
            # control changes distinguishable from the all-zero background.
            if current != 0. and abs(current - event.value) > 64:
                # MuseScore-style files emit 0 immediately followed by 127 at
                # the same frame; write the conflicting event two columns
                # later so interpolation augmentation cannot erase the change.
                column = min(frame + 2, num_frames - 1)
            else:
                column = frame
            pedal_roll[MIN_PIANO:MAX_PIANO + 1, column] = quantize_pedal(event.value)
    full_roll = np.concatenate((piano_roll[None], onset_roll[None], pedal_roll[None]), axis=0)
    if show:
        for roll in (piano_roll, pedal_roll):
            plt.imshow(roll[::-1, :1024], vmin=0, vmax=127)
            plt.show()
    return full_roll
def _save_roll_excerpts(full_roll, midi_filename, path, image_size, start, prefix=''):
    """Slice full_roll (3, 128, T) into image_size-wide windows from `start` and save the non-empty ones.

    The final partial window is zero-padded on the right. Files are named
    ``<prefix><base>_<window_index>.npy`` where <base> is midi_filename
    between its last '/' and '.mid'.
    """
    base = midi_filename[midi_filename.rfind('/') + 1:midi_filename.rfind('.mid')]
    for j in range(start, full_roll.shape[-1], image_size):
        if j + image_size <= full_roll.shape[-1]:
            excerpt = full_roll[:, :, j:j + image_size]
        else:
            # Zero-pad the trailing partial window to 3 x 128 x image_size.
            excerpt = np.zeros((3, full_roll.shape[1], image_size))
            excerpt[:, :, : full_roll.shape[-1] - j] = full_roll[:, :, j:]
        if math.isclose(excerpt.max(), 0.):
            continue  # skip completely silent windows
        np.save(os.path.join(path, prefix + base + '_' + str(j // image_size) + '.npy'),
                excerpt.astype(np.uint8))
        # To disambiguate duplicate names across datasets (VAE training),
        # prepend the dataset name to the saved file name instead.


def preprocess_midi(target='merged', csv_path='merged_midi.csv', fs=100., image_size=128, overlap=False, show=False):
    """Convert every MIDI file listed in csv_path into piano-roll .npy excerpts.

    Each row of the CSV must provide ``midi_filename``, ``split`` and
    ``dataset``. Excerpts are written under ``target/<split>/``. With
    ``overlap=True`` a second pass saves windows shifted by image_size // 2,
    prefixed with ``shift_``.
    """
    df = pd.read_csv(csv_path)
    os.makedirs(target, exist_ok=True)
    for i in tqdm(range(len(df))):
        midi_filename = df.midi_filename[i]
        split = df.split[i]
        dataset = df.dataset[i]
        path = os.path.join(target, split)
        # Bug fix: create the split directory on demand. The original only
        # pre-created 'train' and 'test', so rows with split == 'validation'
        # (which split_csv produces) crashed on np.save.
        os.makedirs(path, exist_ok=True)
        midi_data = pretty_midi.PrettyMIDI(os.path.join(dataset, midi_filename))
        full_roll = get_full_piano_roll(midi_data, fs=fs, show=show)
        _save_roll_excerpts(full_roll, midi_filename, path, image_size, start=0)
        if overlap:
            # Second pass shifted by half a window for augmentation.
            _save_roll_excerpts(full_roll, midi_filename, path, image_size,
                                start=image_size // 2, prefix='shift_')
    return
def main():
    """Generate the piano-roll datasets used downstream."""
    # fs=100, 1.28 s windows, no overlap (windows can be rearranged later).
    preprocess_midi(target='all-128-fs100', csv_path='all_midi.csv', fs=100,
                    image_size=128, overlap=False, show=False)
    # Other variants (uncomment to regenerate):
    # fs=100, 2.56 s windows with overlap, used for VAE training; loaders
    # then select a 1.28 s crop out of each 2.56 s excerpt.
    # preprocess_midi(target='all-256-overlap-fs100', csv_path='all_midi.csv',
    #                 fs=100, image_size=256, overlap=True, show=False)
    # fs=12.5 (0.08 s per column) for the pixel-space diffusion model,
    # rearrangement with length 2.
    # preprocess_midi(target='all-128-fs12.5', csv_path='all_midi.csv',
    #                 fs=12.5, image_size=128, overlap=False, show=False)
# Script entry point: only run when executed directly, not on import.
if __name__ == "__main__":
    main()