""" preprocess_guitarset.py """
|
import os |
|
import glob |
|
import copy |
|
import json |
|
from typing import Dict, List, Tuple, Optional |
|
import numpy as np |
|
import jams |
|
from utils.note_event_dataclasses import Note, NoteEvent |
|
from utils.audio import get_audio_file_info, pitch_shift_audio |
|
from utils.midi import note_event2midi, pitch_shift_midi |
|
from utils.note2event import note2note_event, sort_notes, validate_notes, trim_overlapping_notes |
|
|
|
|
|
def create_note_event_and_note_from_jam(jam_file: str, id: str) -> Tuple[Dict, Dict]:
    """Load a JAMS annotation file and build the notes / note-events dictionaries.

    Every observation of every annotation in the file is scanned. An observation
    becomes a `Note` only when its `value` is a `float` and its `confidence` is
    `None`; all other observations are skipped (presumably this keeps the note
    annotations and drops the other annotation types stored in the same file --
    confirm against the JAMS files in use).

    Args:
        jam_file: Path to the `.jams` annotation file.
        id: GuitarSet identifier stored under the `'guitarset_id'` key.

    Returns:
        A tuple `(notes_dict, note_events_dict)`. NOTE: despite the function
        name, the dictionary holding `'notes'` comes first and the one holding
        `'note_events'` comes second. Both share the keys `'guitarset_id'`,
        `'program'`, `'is_drum'` and `'duration_sec'`.
    """
    jam = jams.load(jam_file)

    notes = []
    for ann in jam.annotations:
        for obs in ann.data:
            # Guard clause: only float-valued observations without a confidence are notes.
            if not isinstance(obs.value, float) or obs.confidence is not None:
                continue
            notes.append(
                Note(
                    is_drum=False,
                    program=24,  # fixed program number, matching the 'program' metadata below
                    onset=obs.time,
                    offset=obs.time + obs.duration,
                    pitch=round(obs.value),  # value is rounded to the nearest integer pitch
                    velocity=1))

    # Sort, validate, and trim notes
    notes = sort_notes(notes)
    notes = validate_notes(notes)
    notes = trim_overlapping_notes(notes)

    def _metadata() -> Dict:
        # A fresh dict (and fresh lists) per call so the two results share no mutable state.
        return {
            'guitarset_id': id,
            'program': [24],
            'is_drum': [0],
            'duration_sec': jam.file_metadata.duration,
        }

    notes_dict = _metadata()
    notes_dict['notes'] = notes

    note_events_dict = _metadata()
    note_events_dict['note_events'] = note2note_event(notes)

    return notes_dict, note_events_dict
|
|
|
|
|
def generate_pitch_shifted_wav_and_midi(file_list: Dict, min_pitch_shift: int = -5, max_pitch_shift: int = 6):
    """Create pitch-shifted MIDI/notes/note-events and audio files for a file list.

    Args:
        file_list: Dictionary whose values are entries containing at least the
            keys `'midi_file'` and `'mix_audio_file'`.
        min_pitch_shift: Lower bound of the pitch-shift range, forwarded to
            `pitch_shift_midi` and `pitch_shift_audio`.
        max_pitch_shift: Upper bound of the pitch-shift range, forwarded to
            `pitch_shift_midi` and `pitch_shift_audio`.

    Several entries may point to the same `'midi_file'` (e.g. different audio
    versions of one recording). Each distinct MIDI file is shifted only once.
    NOTE(review): this assumes the output of `pitch_shift_midi` depends only on
    its arguments, so repeating the call would just rewrite the same files.
    """
    shifted_midi_files = set()

    for entry in file_list.values():
        midi_file = entry['midi_file']
        audio_file = entry['mix_audio_file']

        # Write midi, notes, and note_events (once per distinct source MIDI file)
        if midi_file not in shifted_midi_files:
            shifted_midi_files.add(midi_file)
            pitch_shift_midi(src_midi_file=midi_file,
                             min_pitch_shift=min_pitch_shift,
                             max_pitch_shift=max_pitch_shift,
                             write_midi_file=True,
                             write_notes_file=True,
                             write_note_events_file=True)

        # Write audio (every entry has its own audio file)
        pitch_shift_audio(src_audio_file=audio_file,
                          min_pitch_shift=min_pitch_shift,
                          max_pitch_shift=max_pitch_shift,
                          random_microshift_range=(-10, 11))
|
|
|
|
|
def _load_indexed_file_list(path: str) -> Dict[int, Dict]:
    """Read a file-list JSON and convert its string keys back to integer indices."""
    with open(path, 'r') as f:
        raw = json.load(f)
    return {int(key): value for key, value in raw.items()}


def _dump_indexed_file_list(entries: List[Dict], output_file: str) -> None:
    """Write `entries` to `output_file` as a JSON dict keyed by 0..N-1 and report it."""
    with open(output_file, 'w') as f:
        json.dump(dict(enumerate(entries)), f, indent=4)
    print(f'Created {output_file}')


def preprocess_guitarset16k(data_home: os.PathLike,
                            dataset_name: str = 'guitarset',
                            pitch_shift_range: Optional[Tuple[int, int]] = (-5, 6)) -> None:
    """
    Preprocess GuitarSet annotations and write the file-list indexes.

    Steps:
        1. For every `annotation/*.jams` file under
           `{data_home}/{dataset_name}_yourmt3_16k`, write `*_notes.npy`,
           `*_note_events.npy` and `*.mid` next to the annotation file.
        2. Group the recordings into splits and write one file list per split.
           Each recording contributes two entries, one for the 'mic' and one for
           the 'mix' audio version.
        3. Unless `pitch_shift_range` is None, generate pitch-shifted audio/MIDI
           for all recordings and write `*_pshift_file_list.json` for the
           progression, train and all splits (not for validation/test).

    Args:
        data_home: Root directory that contains the dataset directory and
            receives the `yourmt3_indexes` output directory.
        dataset_name: Dataset name used as directory and file-name prefix.
        pitch_shift_range: `(min_pitch_shift, max_pitch_shift)`; shifts are taken
            from `range(min_pitch_shift, max_pitch_shift)`. None skips step 3.

    Raises:
        ValueError: If a recording ID yields a progression outside 1..3 or a
            player outside 0..5.
        AssertionError: If the dataset does not have the expected number of
            files, or an expected audio/MIDI/notes file is missing.

    Splits:
        - progression_1, progression_2, progression_3
        - train, validation, test (by random selection [4,1,1] player)

    NOTE: the train/validation/test selection uses `np.random.choice` and no
    seed is set inside this function, so the split differs between runs unless
    the caller seeds NumPy's global random state.

    Writes:
        - {dataset_name}_{split}_file_list.json: a dictionary with the following keys:
        {
            index:
            {
                'guitarset_id': guitarset_id,
                'n_frames': (int),
                'mix_audio_file': 'path/to/mix.wav',
                'notes_file': 'path/to/notes.npy',
                'note_events_file': 'path/to/note_events.npy',
                'midi_file': 'path/to/midi.mid',
                'program': List[int],
                'is_drum': List[int], # 0 or 1
            }
        }
    """

    # Directory and file paths
    base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k')
    output_index_dir = os.path.join(data_home, 'yourmt3_indexes')
    os.makedirs(output_index_dir, exist_ok=True)

    # Load annotation files and derive notes, note_events and MIDI from each
    all_ann_files = glob.glob(os.path.join(base_dir, 'annotation/*.jams'), recursive=True)
    assert len(all_ann_files) == 360
    notes_files = {}
    note_events_files = {}
    midi_files = {}
    for ann_file in all_ann_files:
        # The ID is the annotation file name up to its first '.'
        guitarset_id = os.path.basename(ann_file).split('.')[0]
        notes, note_events = create_note_event_and_note_from_jam(ann_file, guitarset_id)

        notes_file = ann_file.replace('.jams', '_notes.npy')
        np.save(notes_file, notes, allow_pickle=True, fix_imports=False)
        print(f'Created {notes_file}')

        note_events_file = ann_file.replace('.jams', '_note_events.npy')
        np.save(note_events_file, note_events, allow_pickle=True, fix_imports=False)
        print(f'Created {note_events_file}')

        midi_file = ann_file.replace('.jams', '.mid')
        note_event2midi(note_events=note_events['note_events'], output_file=midi_file)
        print(f'Created {midi_file}')

        notes_files[guitarset_id] = notes_file
        note_events_files[guitarset_id] = note_events_file
        midi_files[guitarset_id] = midi_file

    # Group recording IDs by split
    guitarset_ids_by_split = {
        'progression_1': [],
        'progression_2': [],
        'progression_3': [],
        'player_0': [],
        'player_1': [],
        'player_2': [],
        'player_3': [],
        'player_4': [],
        'player_5': [],
        'train': [],
        'validation': [],
        'test': [],
        'all': [],
    }

    for ann_file in all_ann_files:
        guitarset_id = os.path.basename(ann_file).split('.')[0]
        # IDs are parsed as '<player>_<name><progression digit>-...': the player
        # is the integer before the first '_', the progression is the last
        # character of the token between the first '_' and the first '-'.
        progression = int(guitarset_id.split('_')[1].split('-')[0][-1])
        player = int(guitarset_id.split('_')[0])

        guitarset_ids_by_split['all'].append(guitarset_id)

        if progression not in (1, 2, 3):
            raise ValueError(f'Invalid progression: {guitarset_id}')
        guitarset_ids_by_split[f'progression_{progression}'].append(guitarset_id)

        if not 0 <= player <= 5:
            raise ValueError(f'Invalid player: {guitarset_id}')
        guitarset_ids_by_split[f'player_{player}'].append(guitarset_id)

    # Sort so that index i refers to the same position in every player's list
    for ids in guitarset_ids_by_split.values():
        ids.sort()
    for i in range(6):
        assert len(guitarset_ids_by_split[f'player_{i}']) == 60

    # For each of the 60 positions, draw a random permutation of the 6 players:
    # the first 4 go to train, the 5th to validation and the 6th to test.
    for i in range(60):
        rand_sel = np.random.choice(6, size=6, replace=False)
        player_train = rand_sel[:4]
        player_valid = rand_sel[4]
        player_test = rand_sel[5]
        for player in player_train:
            guitarset_ids_by_split['train'].append(guitarset_ids_by_split[f'player_{player}'][i])
        guitarset_ids_by_split['validation'].append(guitarset_ids_by_split[f'player_{player_valid}'][i])
        guitarset_ids_by_split['test'].append(guitarset_ids_by_split[f'player_{player_test}'][i])

    assert len(guitarset_ids_by_split['train']) == 240
    assert len(guitarset_ids_by_split['validation']) == 60
    assert len(guitarset_ids_by_split['test']) == 60

    # Create and write the file list of every split
    for split in ['progression_1', 'progression_2', 'progression_3', 'train', 'validation', 'test', 'all']:
        entries = []
        for gid in guitarset_ids_by_split[split]:
            wav_file = {
                'hex': os.path.join(base_dir, 'audio_hex-pickup_original', gid + '_' + 'hex' + '.wav'),
                'hex_cln': os.path.join(base_dir, 'audio_hex-pickup_debleeded', gid + '_' + 'hex_cln' + '.wav'),
                'mic': os.path.join(base_dir, 'audio_mono-mic', gid + '_' + 'mic' + '.wav'),
                'mix': os.path.join(base_dir, 'audio_mono-pickup_mix', gid + '_' + 'mix' + '.wav'),
            }
            # All four audio versions must exist, although only two are indexed
            for path in wav_file.values():
                assert os.path.exists(path)

            for ver in ['mic', 'mix']:
                entries.append({
                    'guitarset_id': gid + '_' + ver,
                    'n_frames': get_audio_file_info(wav_file[ver])[1],
                    'mix_audio_file': wav_file[ver],
                    'notes_file': notes_files[gid],
                    'note_events_file': note_events_files[gid],
                    'midi_file': midi_files[gid],
                    'program': [24],
                    'is_drum': [0],
                })

        output_file = os.path.join(output_index_dir, f'{dataset_name}_{split}_file_list.json')
        _dump_indexed_file_list(entries, output_file)

    if pitch_shift_range is None:
        return
    min_pitch_shift, max_pitch_shift = pitch_shift_range

    # Generate pitch-shifted wav and MIDI for every entry of the 'all' file list
    file_list_all_path = os.path.join(output_index_dir, f'{dataset_name}_all_file_list.json')
    file_list_all = _load_indexed_file_list(file_list_all_path)
    generate_pitch_shifted_wav_and_midi(file_list_all, min_pitch_shift=min_pitch_shift, max_pitch_shift=max_pitch_shift)

    # Create the pitch-shift-augmented file lists
    for split in ['progression_1', 'progression_2', 'progression_3', 'train', 'all']:
        src_file_list_path = os.path.join(output_index_dir, f'{dataset_name}_{split}_file_list.json')
        src_file_list = _load_indexed_file_list(src_file_list_path)

        entries = []
        for v in src_file_list.values():
            for pitch_shift in range(min_pitch_shift, max_pitch_shift):
                entry = copy.deepcopy(v)
                # A shift of 0 keeps the unmodified source entry (no 'pitch_shift' key is added)
                if pitch_shift != 0:
                    shifted_audio_file = v['mix_audio_file'].replace('.wav', f'_pshift{pitch_shift}.wav')
                    assert os.path.isfile(shifted_audio_file)
                    entry['mix_audio_file'] = shifted_audio_file
                    entry['n_frames'] = get_audio_file_info(shifted_audio_file)[1]
                    entry['pitch_shift'] = pitch_shift

                    shifted_midi_file = v['midi_file'].replace('.mid', f'_pshift{pitch_shift}.mid')
                    shifted_notes_file = v['notes_file'].replace('_notes', f'_pshift{pitch_shift}_notes')
                    shifted_note_events_file = v['note_events_file'].replace('_note_events',
                                                                             f'_pshift{pitch_shift}_note_events')
                    assert os.path.isfile(shifted_midi_file)
                    assert os.path.isfile(shifted_notes_file)
                    assert os.path.isfile(shifted_note_events_file)
                    entry['midi_file'] = shifted_midi_file
                    entry['notes_file'] = shifted_notes_file
                    entry['note_events_file'] = shifted_note_events_file
                entries.append(entry)
        assert len(entries) == len(src_file_list) * (max_pitch_shift - min_pitch_shift)

        output_file = os.path.join(output_index_dir, f'{dataset_name}_{split}_pshift_file_list.json')
        _dump_indexed_file_list(entries, output_file)
|
|
|
|
|
def create_filelist_by_style_guitarset16k(data_home: os.PathLike, dataset_name: str = 'guitarset') -> None:
    """Split the 'all' file lists into one pair of file lists per playing style.

    Reads `{dataset_name}_all_pshift_file_list.json` and
    `{dataset_name}_all_file_list.json` from `{data_home}/yourmt3_indexes` and,
    for each style in ('BN', 'Funk', 'SS', 'Jazz', 'Rock'), writes:

        - {dataset_name}_{style}_pshift_file_list.json (from the pitch-shifted list)
        - {dataset_name}_{style}_file_list.json (from the plain list)

    An entry belongs to a style when the style name occurs as a substring of
    its `'guitarset_id'`. Entries are re-indexed from 0 in their original order.

    Args:
        data_home: Root directory containing the `yourmt3_indexes` directory.
        dataset_name: Dataset name used as file-name prefix.

    Raises:
        AssertionError: If the pitch-shifted list does not hold exactly 7920
            entries or the plain list exactly 720 entries. These are the sizes
            `preprocess_guitarset16k` produces with its default
            `pitch_shift_range` (360 recordings x 2 audio versions, x 11 shifts).
        FileNotFoundError: If either input file list does not exist.
    """
    # Directory and file paths
    output_index_dir = os.path.join(data_home, 'yourmt3_indexes')
    os.makedirs(output_index_dir, exist_ok=True)

    # Load the pitch-shifted file list
    file_list_pshift_all_path = os.path.join(output_index_dir, f'{dataset_name}_all_pshift_file_list.json')
    with open(file_list_pshift_all_path, 'r') as f:
        fl_pshift = json.load(f)
    assert len(fl_pshift) == 7920

    # Load the plain file list
    file_list_all_path = os.path.join(output_index_dir, f'{dataset_name}_all_file_list.json')
    with open(file_list_all_path, 'r') as f:
        fl = json.load(f)
    assert len(fl) == 720

    # (source list, output file-name suffix); the pitch-shifted list is written first
    sources = ((fl_pshift, 'pshift_file_list'), (fl, 'file_list'))

    styles = ['BN', 'Funk', 'SS', 'Jazz', 'Rock']
    for style in styles:
        for source, suffix in sources:
            # Entries are only serialized, never modified, so no copy is needed
            matched = [v for v in source.values() if style in v['guitarset_id']]
            output_file = os.path.join(output_index_dir, f'{dataset_name}_{style}_{suffix}.json')
            with open(output_file, 'w') as f:
                json.dump(dict(enumerate(matched)), f, indent=4)
            print(f'Created {output_file}')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|