"""preprocess_maps.py

Preprocess the MAPS piano dataset for YourMT3: convert MIDI annotations into
note / note-event arrays, rewrite the MIDI files at 120 bpm, and write
per-split file-list indexes as JSON.
"""
import os
import glob
import re
import json
from typing import Dict, List, Tuple

import numpy as np

from utils.audio import get_audio_file_info
from utils.midi import midi2note, note_event2midi
from utils.note2event import note2note_event, note_event2event
from utils.event2note import event2note_event
from utils.note_event_dataclasses import Note, NoteEvent
from utils.utils import note_event2token2note_event_sanity_check


def create_note_event_and_note_from_midi(mid_file: str,
                                         id: str,
                                         ignore_pedal: bool = False) -> Tuple[Dict, Dict]:
    """Extracts notes and note events, with metadata, from a MIDI file.

    Returns:
        notes (dict): notes and metadata.
        note_events (dict): note events and metadata.
    """
    notes, dur_sec = midi2note(
        mid_file,
        binary_velocity=True,
        ch_9_as_drum=False,
        force_all_drum=False,
        force_all_program_to=0,
        trim_overlap=True,
        fix_offset=True,
        quantize=True,
        verbose=0,
        minimum_offset_sec=0.01,
        drum_offset_sec=0.01,
        ignore_pedal=ignore_pedal)
    return {
        'maps_id': id,
        'program': [0],
        'is_drum': [0],
        'duration_sec': dur_sec,
        'notes': notes,
    }, {
        'maps_id': id,
        'program': [0],
        'is_drum': [0],
        'duration_sec': dur_sec,
        'note_events': note2note_event(notes),
    }
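
# Usage sketch (the MAPS id below is illustrative only). The two returned
# dicts share the same metadata and differ only in their payload key:
#
#   notes, note_events = create_note_event_and_note_from_midi(
#       'MAPS_MUS-alb_se2_AkPnBcht.mid', id='MAPS_MUS-alb_se2_AkPnBcht')
#   notes['notes']              # List[Note] with binary velocity
#   note_events['note_events']  # List[NoteEvent] derived via note2note_event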


def rewrite_midi_120bpm(file: os.PathLike, note_events: List[NoteEvent]) -> None:
    """Rewrites the MIDI file at a fixed 120 bpm."""
    note_event2midi(note_events, file)
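
# Note: the 120 bpm in the name reflects an assumption that note_event2midi
# writes MIDI at its default tempo of 120 bpm; only note timings are preserved.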


def preprocess_maps16k(data_home: os.PathLike,
                       dataset_name: str = 'maps',
                       ignore_pedal: bool = False,
                       sanity_check: bool = False) -> None:
    """
    Splits:
        - train: following the convention described in Cheuk et al. (2021),
          we filter out the songs overlapping with the MAPS test set;
          139 pieces from the MUS folder are left for training.
        - test: 60 files (MUS).
        - all: 270 files, including the (unfiltered) train and test splits.
          This is used for the evaluation on the MusicNet test set.

    Writes:
        - {dataset_name}_{split}_file_list.json: a dictionary of the form:
            {
                index: {
                    'maps_id': maps_id,
                    'n_frames': (int),
                    'mix_audio_file': 'path/to/mix.wav',
                    'notes_file': 'path/to/notes.npy',
                    'note_events_file': 'path/to/note_events.npy',
                    'midi_file': 'path/to/midi.mid',
                    'program': List[int],
                    'is_drum': List[int],  # 0 or 1
                }
            }
    """
    base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k')
    output_index_dir = os.path.join(data_home, 'yourmt3_indexes')
    os.makedirs(output_index_dir, exist_ok=True)

    # Collect MIDI file paths for each split (MUS folder only). The 'all'
    # pattern matches both the train and test trees.
    train_mid_pattern = os.path.join(base_dir, 'train/**/MUS/*.mid')
    test_mid_pattern = os.path.join(base_dir, 'test/**/MUS/*.mid')
    all_mid_pattern = os.path.join(base_dir, '**/MUS/*.mid')

    train_mid_files = glob.glob(train_mid_pattern, recursive=True)
    test_mid_files = glob.glob(test_mid_pattern, recursive=True)
    all_mid_files = glob.glob(all_mid_pattern, recursive=True)

    # Extract the song name from each test file to filter the train split.
    songnames_in_test_files = []
    for file in test_mid_files:
        filename = os.path.basename(file)
        match = re.search(r"MAPS_MUS-([\w-]+)_", filename)
        if match:
            songnames_in_test_files.append(match.group(1))
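
    # For example (illustrative file name, following the MAPS convention),
    # 'MAPS_MUS-chpn_op25_e4_ENSTDkCl.mid' yields the song name
    # 'chpn_op25_e4': the greedy group backtracks to the last '_' before the
    # instrument/condition code.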

    # Filter out training pieces whose song name appears in the test set.
    filtered_train_mid_files = []
    filtered_train_wav_files = []
    for train_file in train_mid_files:
        if not any(songname in os.path.basename(train_file)
                   for songname in songnames_in_test_files):
            filtered_train_mid_files.append(train_file)
            filtered_train_wav_files.append(train_file.replace('.mid', '.wav'))
    assert len(filtered_train_mid_files) == len(filtered_train_wav_files) == 139

    # For each MIDI file: extract notes and note events, optionally run a
    # sanity check, save both as .npy, and rewrite the MIDI at 120 bpm.
    for i, mid_file in enumerate(all_mid_files):
        maps_id = os.path.basename(mid_file)[:-4]
        notes, note_events = create_note_event_and_note_from_midi(
            mid_file=mid_file, id=maps_id, ignore_pedal=ignore_pedal)

        if sanity_check:
            print(f'Sanity check for {i}: {maps_id}...')
            note_event2token2note_event_sanity_check(note_events['note_events'], notes['notes'])

        notes_file = mid_file.replace('.mid', '_notes.npy')
        np.save(notes_file, notes, allow_pickle=True, fix_imports=False)
        print(f'Created {notes_file}')

        note_events_file = mid_file.replace('.mid', '_note_events.npy')
        np.save(note_events_file, note_events, allow_pickle=True, fix_imports=False)
        print(f'Created {note_events_file}')

        rewrite_midi_120bpm(mid_file, note_events['note_events'])
        print(f'Overwrote {mid_file} with 120 bpm')

    mid_files_by_split = {
        'train': filtered_train_mid_files,
        'test': test_mid_files,
        'all': all_mid_files,
    }

    # Write a JSON index per split, validating that each wav file exists.
    for split in ['train', 'test', 'all']:
        file_list = {}
        for i, mid_file in enumerate(mid_files_by_split[split]):
            wav_file = mid_file.replace('.mid', '.wav')
            if not os.path.exists(wav_file):
                raise FileNotFoundError(f'Wav file not found: {wav_file}')

            file_list[i] = {
                'maps_id': os.path.basename(mid_file)[:-4],
                'n_frames': get_audio_file_info(wav_file)[1],
                'mix_audio_file': wav_file,
                'notes_file': mid_file.replace('.mid', '_notes.npy'),
                'note_events_file': mid_file.replace('.mid', '_note_events.npy'),
                'midi_file': mid_file,
                'program': [0],
                'is_drum': [0],
            }

        output_file = os.path.join(output_index_dir, f'{dataset_name}_{split}_file_list.json')
        with open(output_file, 'w') as f:
            json.dump(file_list, f, indent=4)
        print(f'Created {output_file}')
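

# A minimal invocation sketch; 'data_home' below is a hypothetical path and
# should point at the directory containing the 'maps_yourmt3_16k' folder.
if __name__ == '__main__':
    preprocess_maps16k(data_home='../../data',  # hypothetical path
                       dataset_name='maps',
                       ignore_pedal=False,
                       sanity_check=True)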