import os

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import pretty_midi
import pandas as pd
import numpy as np
from tqdm import tqdm
import math

from music_rule_guidance.music_rules import MAX_PIANO, MIN_PIANO

import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (6, 3)
plt.rcParams['figure.dpi'] = 300
plt.rcParams['savefig.dpi'] = 300

# MIDI control change (CC) number of the sustain pedal.
CC_SUSTAIN_PEDAL = 64

def split_csv(csv_path='merged_midi.csv'):
    """Split a metadata csv into per-split csv files (train/validation/test)."""
    df = pd.read_csv(csv_path)
    # The output directory is named after the csv file; make sure it exists.
    save_name = csv_path[:csv_path.rfind('.csv')]
    os.makedirs(save_name, exist_ok=True)
    for split in ['train', 'validation', 'test']:
        path = os.path.join(save_name, split + '.csv')
        df_sub = df[df.split == split]
        df_sub.to_csv(path, index=False)
    return

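# Example (a sketch; assumes merged_midi.csv has a `split` column with the values
# 'train' / 'validation' / 'test'):
#   split_csv('merged_midi.csv')
#   # -> merged_midi/train.csv, merged_midi/validation.csv, merged_midi/test.csv
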
def quantize_pedal(value, num_bins=8):
    """Quantize an integer CC value in [0, 127] into `num_bins` bins and return the bin's center value."""
    if value < 0 or value > 127:
        raise ValueError("Value should be between 0 and 127")
    bin_size = 128 // num_bins
    bin_index = value // bin_size
    bin_center = bin_size * bin_index + bin_size // 2
    # Clamp in case the last bin's center overshoots the valid CC range.
    if bin_center > 127:
        bin_center = 127
    return bin_center

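# Worked example with the default num_bins=8 (bin_size = 16):
#   quantize_pedal(0)   -> 8    (bin 0)
#   quantize_pedal(64)  -> 72   (bin 4)
#   quantize_pedal(127) -> 120  (bin 7)
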
def get_full_piano_roll(midi_data, fs, show=False):
    """Return a (3, 128, T) array stacking the piano roll, onset roll, and sustain-pedal roll."""
    # Note: `onset=True` (returning both the note roll and the onset roll) is not part of
    # stock pretty_midi; this assumes the patched version used by this project.
    piano_roll, onset_roll = midi_data.get_piano_roll(fs=fs, pedal_threshold=None, onset=True)
    # Encode sustain-pedal CC values across the piano pitch range so the pedal
    # channel has the same shape as the piano roll.
    pedal_roll = np.zeros_like(piano_roll)
    for instru in midi_data.instruments:
        pedal_changes = [_e for _e in instru.control_changes if _e.number == CC_SUSTAIN_PEDAL]
        for cc in pedal_changes:
            time_now = int(cc.time * fs)
            if time_now < pedal_roll.shape[-1]:
                # If this frame already holds a very different pedal value (an on/off flip
                # landing on the same frame), write the new value two frames later so both
                # events remain visible instead of one overwriting the other.
                if pedal_roll[MIN_PIANO, time_now] != 0. and abs(pedal_roll[MIN_PIANO, time_now] - cc.value) > 64:
                    pedal_roll[MIN_PIANO:MAX_PIANO + 1, min(time_now + 2, pedal_roll.shape[-1] - 1)] = quantize_pedal(cc.value)
                else:
                    pedal_roll[MIN_PIANO:MAX_PIANO + 1, time_now] = quantize_pedal(cc.value)
    # Stack the three channels: (piano, onset, pedal).
    full_roll = np.concatenate((piano_roll[None], onset_roll[None], pedal_roll[None]), axis=0)
    if show:
        plt.imshow(piano_roll[::-1, :1024], vmin=0, vmax=127)
        plt.show()
        plt.imshow(pedal_roll[::-1, :1024], vmin=0, vmax=127)
        plt.show()
    return full_roll

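# Example (a sketch; 'example.mid' is a placeholder path):
#   midi_data = pretty_midi.PrettyMIDI('example.mid')
#   full_roll = get_full_piano_roll(midi_data, fs=100.)
#   full_roll.shape  # -> (3, 128, T), with roughly fs frames per second of music
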
def preprocess_midi(target='merged', csv_path='merged_midi.csv', fs=100., image_size=128, overlap=False, show=False):
    """Cut every MIDI file listed in csv_path into (3, 128, image_size) uint8 excerpts saved as .npy under target/<split>/."""
    df = pd.read_csv(csv_path)
    total_pieces = len(df)
    if not os.path.exists(target):
        os.makedirs(target)
    # Create an output folder for every split that appears in the csv
    # (train/test, plus validation if the csv contains it).
    for split in df.split.unique():
        path = os.path.join(target, split)
        if not os.path.exists(path):
            os.makedirs(path)
    for i in tqdm(range(total_pieces)):
        midi_filename = df.midi_filename[i]
        split = df.split[i]
        dataset = df.dataset[i]
        path = os.path.join(target, split)
        midi_data = pretty_midi.PrettyMIDI(os.path.join(dataset, midi_filename))
        full_roll = get_full_piano_roll(midi_data, fs=fs, show=show)
        # Non-overlapping windows of width image_size; the last window is zero-padded.
        for j in range(0, full_roll.shape[-1], image_size):
            if j + image_size <= full_roll.shape[-1]:
                full_roll_excerpt = full_roll[:, :, j:j + image_size]
            else:
                full_roll_excerpt = np.zeros((3, full_roll.shape[1], image_size))
                full_roll_excerpt[:, :, : full_roll.shape[-1] - j] = full_roll[:, :, j:]
            # Skip excerpts that contain no notes and no pedal activity.
            empty_roll = math.isclose(full_roll_excerpt.max(), 0.)
            if not empty_roll:
                last_slash_index = midi_filename.rfind('/')
                dot_mid_index = midi_filename.rfind('.mid')
                save_name = midi_filename[last_slash_index + 1:dot_mid_index]
                full_roll_excerpt = full_roll_excerpt.astype(np.uint8)
                np.save(os.path.join(path, save_name + '_' + str(j // image_size) + '.npy'), full_roll_excerpt)
        if overlap:
            # Second pass shifted by half a window, saved with a 'shift_' prefix,
            # to produce 50%-overlapping excerpts.
            for j in range(image_size // 2, full_roll.shape[-1], image_size):
                if j + image_size <= full_roll.shape[-1]:
                    full_roll_excerpt = full_roll[:, :, j:j + image_size]
                else:
                    full_roll_excerpt = np.zeros((3, full_roll.shape[1], image_size))
                    full_roll_excerpt[:, :, : full_roll.shape[-1] - j] = full_roll[:, :, j:]
                empty_roll = math.isclose(full_roll_excerpt.max(), 0.)
                if not empty_roll:
                    last_slash_index = midi_filename.rfind('/')
                    dot_mid_index = midi_filename.rfind('.mid')
                    save_name = midi_filename[last_slash_index + 1:dot_mid_index]
                    full_roll_excerpt = full_roll_excerpt.astype(np.uint8)
                    np.save(os.path.join(path, 'shift_' + save_name + '_' + str(j // image_size) + '.npy'), full_roll_excerpt)
    return

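# Example of the on-disk layout (a sketch; 'piece.mid' is a placeholder name): with
# target='all-128-fs100' and image_size=128, a train-split row whose midi_filename ends
# in 'piece.mid' yields all-128-fs100/train/piece_0.npy, piece_1.npy, ..., each a
# uint8 array of shape (3, 128, 128).
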
def main():
    preprocess_midi(target='all-128-fs100', csv_path='all_midi.csv', fs=100, image_size=128, overlap=False, show=False)


if __name__ == "__main__":
    main()