# YourMT3 / amt/src/utils/preprocess/preprocess_guitarset.py
# author: mimbres, commit: a03c9b4, size: 17.3 kB (hosting-page header, kept as a comment)
""" preprocess_guitarset.py """
import os
import glob
import copy
import json
from typing import Dict, List, Tuple, Optional
import numpy as np
import jams
from utils.note_event_dataclasses import Note, NoteEvent
from utils.audio import get_audio_file_info, pitch_shift_audio
from utils.midi import note_event2midi, pitch_shift_midi
from utils.note2event import note2note_event, sort_notes, validate_notes, trim_overlapping_notes
def create_note_event_and_note_from_jam(jam_file: str, id: str) -> Tuple[Dict, Dict]:
    """Load a GuitarSet JAMS file and convert its note observations to notes and note events.

    An observation becomes a `Note` (program 24, velocity 1, pitch = `round(value)`) when its
    value is a float and its confidence is None. The notes are then sorted, validated and
    trimmed of overlaps, in that order.

    Args:
        jam_file: path of the `.jams` annotation file.
        id: GuitarSet id stored as 'guitarset_id' in both outputs (the name shadows the builtin
            but is kept for caller compatibility).

    Returns:
        (notes_dict, note_events_dict): both contain 'guitarset_id', 'program' ([24]),
        'is_drum' ([0]) and 'duration_sec' (from the JAMS file metadata); the first adds
        'notes', the second adds 'note_events' (from `note2note_event`).
    """
    jam = jams.load(jam_file)

    notes = []
    for ann in jam.annotations:
        for obs in ann.data:
            # NOTE(review): "float value and no confidence" presumably selects the per-string
            # note annotations (other annotation types carry non-float values or a confidence).
            if not isinstance(obs.value, float) or obs.confidence is not None:
                continue
            notes.append(
                Note(is_drum=False,
                     program=24,
                     onset=obs.time,
                     offset=obs.time + obs.duration,
                     pitch=round(obs.value),
                     velocity=1))

    # Sort, validate, and trim notes
    notes = trim_overlapping_notes(validate_notes(sort_notes(notes)))

    def _metadata() -> Dict:
        # A fresh dict (with fresh lists) per output, so the two results share no mutable state.
        return {
            'guitarset_id': id,
            'program': [24],
            'is_drum': [0],
            'duration_sec': jam.file_metadata.duration,
        }

    return {**_metadata(), 'notes': notes}, {**_metadata(), 'note_events': note2note_event(notes)}
def generate_pitch_shifted_wav_and_midi(file_list: Dict, min_pitch_shift: int = -5, max_pitch_shift: int = 6):
    """Write pitch-shifted versions of the MIDI and mix audio of every entry in `file_list`.

    Args:
        file_list: mapping whose values each provide 'midi_file' and 'mix_audio_file' paths.
        min_pitch_shift: lower bound of the shift range, forwarded to both pitch-shift helpers.
        max_pitch_shift: upper bound of the shift range, forwarded to both pitch-shift helpers.

    The MIDI helper is called with write_midi_file/write_notes_file/write_note_events_file all
    set to True; the audio helper with random_microshift_range=(-10, 11). Returns None.
    """
    shift_range = {'min_pitch_shift': min_pitch_shift, 'max_pitch_shift': max_pitch_shift}
    for entry in file_list.values():
        src_midi, src_audio = entry['midi_file'], entry['mix_audio_file']
        # MIDI (plus notes / note_events outputs) first, then the mix audio.
        pitch_shift_midi(src_midi_file=src_midi,
                         write_midi_file=True,
                         write_notes_file=True,
                         write_note_events_file=True,
                         **shift_range)
        pitch_shift_audio(src_audio_file=src_audio, random_microshift_range=(-10, 11), **shift_range)
def _convert_guitarset_annotations(
        all_ann_files: List[str]) -> Tuple[Dict[str, str], Dict[str, str], Dict[str, str]]:
    """Write notes, note events and MIDI files next to each GuitarSet `.jams` annotation.

    For every `<id>.jams` this writes `<id>_notes.npy` and `<id>_note_events.npy` (pickled dicts
    from `create_note_event_and_note_from_jam`) and `<id>.mid`.

    Returns:
        (notes_files, note_events_files, midi_files): dicts mapping guitarset_id to written paths.
    """
    notes_files, note_events_files, midi_files = {}, {}, {}
    for ann_file in all_ann_files:
        guitarset_id = os.path.basename(ann_file).split('.')[0]
        notes, note_events = create_note_event_and_note_from_jam(ann_file, guitarset_id)

        notes_file = ann_file.replace('.jams', '_notes.npy')
        np.save(notes_file, notes, allow_pickle=True, fix_imports=False)
        print(f'Created {notes_file}')

        note_events_file = ann_file.replace('.jams', '_note_events.npy')
        np.save(note_events_file, note_events, allow_pickle=True, fix_imports=False)
        print(f'Created {note_events_file}')

        # Create a midi file from the note_events
        midi_file = ann_file.replace('.jams', '.mid')
        note_event2midi(note_events=note_events['note_events'], output_file=midi_file)
        print(f'Created {midi_file}')

        notes_files[guitarset_id] = notes_file
        note_events_files[guitarset_id] = note_events_file
        midi_files[guitarset_id] = midi_file
    return notes_files, note_events_files, midi_files


def _split_guitarset_ids(guitarset_ids: List[str], rng) -> Dict[str, List[str]]:
    """Group GuitarSet ids by progression and player, and draw a paired train/validation/test split.

    Ids are parsed as '<player>_<style><progression>-...', e.g. '00_BN1-129-Eb_comp' is player 0,
    progression 1. The 'progression_*', 'player_*' and 'all' lists are sorted. For each of the 60
    recordings per player, `rng.choice` permutes the 6 players: 4 go to 'train', 1 to 'validation'
    and 1 to 'test'.

    Args:
        guitarset_ids: the 360 GuitarSet ids.
        rng: `numpy.random` or a `numpy.random.RandomState`; only `rng.choice` is used.

    Raises:
        ValueError: if an id has a progression outside 1-3 or a player outside 0-5.
    """
    split_ids: Dict[str, List[str]] = {
        'progression_1': [],
        'progression_2': [],
        'progression_3': [],
        **{f'player_{p}': [] for p in range(6)},
        'train': [],  # random selection of 4 players per recording
        'validation': [],  # random selection of 1 player per recording
        'test': [],  # random selection of 1 player per recording
        'all': [],
    }
    for gid in guitarset_ids:
        progression = int(gid.split('_')[1].split('-')[0][-1])
        player = int(gid.split('_')[0])
        split_ids['all'].append(gid)
        if progression not in (1, 2, 3):
            raise ValueError(f'Invalid progression: {gid}')
        split_ids[f'progression_{progression}'].append(gid)
        if player not in range(6):
            raise ValueError(f'Invalid player: {gid}')
        split_ids[f'player_{player}'].append(gid)

    for ids in split_ids.values():
        ids.sort()
    for p in range(6):
        n_ids = len(split_ids[f'player_{p}'])
        assert n_ids == 60, f'player_{p} has {n_ids} recordings, expected 60'

    for i in range(60):
        # Pairing the i-th recording of every player is only meaningful if the sorted per-player
        # lists line up, i.e. all six ids share the same suffix after the player prefix.
        suffixes = {split_ids[f'player_{p}'][i].split('_', 1)[1] for p in range(6)}
        assert len(suffixes) == 1, f'Players are not aligned at index {i}: {sorted(suffixes)}'
        rand_sel = rng.choice(6, size=6, replace=False)
        for p in rand_sel[:4]:
            split_ids['train'].append(split_ids[f'player_{p}'][i])
        split_ids['validation'].append(split_ids[f'player_{rand_sel[4]}'][i])
        split_ids['test'].append(split_ids[f'player_{rand_sel[5]}'][i])

    assert len(split_ids['train']) == 240
    assert len(split_ids['validation']) == 60
    assert len(split_ids['test']) == 60
    return split_ids


def _dump_indexed_file_list(entries: List[Dict], output_file: str) -> None:
    """Write `entries` to `output_file` as a JSON object keyed by position (0..N-1)."""
    with open(output_file, 'w') as f:
        json.dump(dict(enumerate(entries)), f, indent=4)
    print(f'Created {output_file}')


def _load_indexed_file_list(file_list_path: str) -> Dict[int, Dict]:
    """Read a JSON file list and convert its string keys back to ints."""
    with open(file_list_path, 'r') as f:
        return {int(key): value for key, value in json.load(f).items()}


def _write_guitarset_pshift_file_lists(output_index_dir: str, dataset_name: str, min_pitch_shift: int,
                                       max_pitch_shift: int) -> None:
    """Generate pitch-shifted audio/MIDI for the 'all' list and write `*_pshift_file_list.json` files.

    Each source entry yields one entry per shift in range(min_pitch_shift, max_pitch_shift). The
    shift-0 entry is an unmodified copy (without a 'pitch_shift' key); every other entry points at
    the `_pshift{n}` audio/MIDI/notes/note_events files, which must exist, and has 'pitch_shift': n.
    """
    def list_path(split: str) -> str:
        return os.path.join(output_index_dir, f'{dataset_name}_{split}_file_list.json')

    # Generate pitch shifted wav and MIDI
    generate_pitch_shifted_wav_and_midi(_load_indexed_file_list(list_path('all')),
                                        min_pitch_shift=min_pitch_shift,
                                        max_pitch_shift=max_pitch_shift)

    # Only these splits get pitch-shifted lists; validation/test are not augmented.
    for split in ['progression_1', 'progression_2', 'progression_3', 'train', 'all']:
        src_file_list = _load_indexed_file_list(list_path(split))
        entries = []
        for v in src_file_list.values():
            for pitch_shift in range(min_pitch_shift, max_pitch_shift):
                entry = copy.deepcopy(v)
                if pitch_shift != 0:
                    shifted_audio_file = v['mix_audio_file'].replace('.wav', f'_pshift{pitch_shift}.wav')
                    assert os.path.isfile(shifted_audio_file), f'Missing {shifted_audio_file}'
                    entry['mix_audio_file'] = shifted_audio_file
                    entry['n_frames'] = get_audio_file_info(shifted_audio_file)[1]
                    entry['pitch_shift'] = pitch_shift

                    shifted_midi_file = v['midi_file'].replace('.mid', f'_pshift{pitch_shift}.mid')
                    shifted_notes_file = v['notes_file'].replace('_notes', f'_pshift{pitch_shift}_notes')
                    shifted_note_events_file = v['note_events_file'].replace(
                        '_note_events', f'_pshift{pitch_shift}_note_events')
                    for path in (shifted_midi_file, shifted_notes_file, shifted_note_events_file):
                        assert os.path.isfile(path), f'Missing {path}'
                    entry['midi_file'] = shifted_midi_file
                    entry['notes_file'] = shifted_notes_file
                    entry['note_events_file'] = shifted_note_events_file
                entries.append(entry)
        assert len(entries) == len(src_file_list) * (max_pitch_shift - min_pitch_shift)
        _dump_indexed_file_list(entries,
                                os.path.join(output_index_dir, f'{dataset_name}_{split}_pshift_file_list.json'))


def preprocess_guitarset16k(data_home: os.PathLike,
                            dataset_name: str = 'guitarset',
                            pitch_shift_range: Optional[Tuple[int, int]] = (-5, 6),
                            seed: Optional[int] = None) -> None:
    """
    Convert GuitarSet annotations and write YourMT3 file-list indexes.

    Reads `{data_home}/{dataset_name}_yourmt3_16k/annotation/*.jams` (360 files expected). It writes
    `_notes.npy`, `_note_events.npy` and `.mid` files alongside each annotation. Audio is not converted
    here: the hex, hex_cln, mic and mix wav versions must already exist under that directory,
    and only 'mic' and 'mix' are indexed (two entries per recording).

    Splits:
        - progression_1, progression_2, progression_3
        - train, validation, test (per recording, a random [4, 1, 1] assignment of the 6 players)
        - all

    Args:
        data_home: root containing the dataset directory and the `yourmt3_indexes` output directory.
        dataset_name: prefix of the dataset directory and of the output json file names.
        pitch_shift_range: (min, max); shifts in range(min, max) are generated and indexed.
            None skips pitch shifting.
        seed: optional seed for the train/validation/test split. None (default) draws from numpy's
            global random state as before; an int makes the split reproducible.

    Writes:
        - {dataset_name}_{split}_file_list.json: a dictionary with the following keys:
        {
            index:
            {
                'guitarset_id': guitarset_id,
                'n_frames': (int),
                'mix_audio_file': 'path/to/mix.wav',
                'notes_file': 'path/to/notes.npy',
                'note_events_file': 'path/to/note_events.npy',
                'midi_file': 'path/to/midi.mid',
                'program': List[int],
                'is_drum': List[int], # 0 or 1
            }
        }
        - {dataset_name}_{split}_pshift_file_list.json for progression_1..3, train and all when
          pitch_shift_range is not None (same schema; non-zero shifts add 'pitch_shift').
    """
    # Directory and file paths
    base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k')
    output_index_dir = os.path.join(data_home, 'yourmt3_indexes')
    os.makedirs(output_index_dir, exist_ok=True)

    # Process annotations
    all_ann_files = glob.glob(os.path.join(base_dir, 'annotation/*.jams'), recursive=True)
    assert len(all_ann_files) == 360, f'Expected 360 GuitarSet annotations, found {len(all_ann_files)}'
    notes_files, note_events_files, midi_files = _convert_guitarset_annotations(all_ann_files)

    # `np.random` itself reproduces the original global-state behaviour when no seed is given.
    rng = np.random if seed is None else np.random.RandomState(seed)
    guitarset_ids = [os.path.basename(ann_file).split('.')[0] for ann_file in all_ann_files]
    guitarset_ids_by_split = _split_guitarset_ids(guitarset_ids, rng)

    # Create file_list.json
    for split in ['progression_1', 'progression_2', 'progression_3', 'train', 'validation', 'test', 'all']:
        entries = []
        for gid in guitarset_ids_by_split[split]:
            wav_file = {
                'hex': os.path.join(base_dir, 'audio_hex-pickup_original', gid + '_hex.wav'),
                'hex_cln': os.path.join(base_dir, 'audio_hex-pickup_debleeded', gid + '_hex_cln.wav'),
                'mic': os.path.join(base_dir, 'audio_mono-mic', gid + '_mic.wav'),
                'mix': os.path.join(base_dir, 'audio_mono-pickup_mix', gid + '_mix.wav'),
            }
            # All four versions must exist, although only 'mic' and 'mix' are indexed.
            for path in wav_file.values():
                assert os.path.exists(path), f'Missing audio file: {path}'
            for ver in ['mic', 'mix']:
                entries.append({
                    'guitarset_id': gid + '_' + ver,
                    'n_frames': get_audio_file_info(wav_file[ver])[1],
                    'mix_audio_file': wav_file[ver],
                    'notes_file': notes_files[gid],
                    'note_events_file': note_events_files[gid],
                    'midi_file': midi_files[gid],
                    'program': [24],
                    'is_drum': [0],
                })
        _dump_indexed_file_list(entries, os.path.join(output_index_dir, f'{dataset_name}_{split}_file_list.json'))

    if pitch_shift_range is None:
        return
    min_pitch_shift, max_pitch_shift = pitch_shift_range
    _write_guitarset_pshift_file_lists(output_index_dir, dataset_name, min_pitch_shift, max_pitch_shift)
def create_filelist_by_style_guitarset16k(data_home: os.PathLike, dataset_name: str = 'guitarset') -> None:
    """Write per-style GuitarSet file lists from the combined 'all' lists.

    Reads from `{data_home}/yourmt3_indexes`, which is where `preprocess_guitarset16k` writes them:
        - `{dataset_name}_all_pshift_file_list.json` (7920 entries expected), and
        - `{dataset_name}_all_file_list.json` (720 entries expected).
    For each style in ('BN', 'Funk', 'SS', 'Jazz', 'Rock') it writes
        - `{dataset_name}_{style}_pshift_file_list.json`: pitch-shifted entries, for training;
        - `{dataset_name}_{style}_file_list.json`: unshifted entries, for testing.
    Outputs keep the source order and are re-keyed 0..N-1.

    An entry belongs to a style when the style token of its 'guitarset_id' (the text between the
    player prefix and the progression digit, e.g. 'BN' in '00_BN1-129-Eb_comp_mic') equals the
    style exactly, rather than the style name merely appearing somewhere in the id.

    Raises:
        AssertionError: if either source list does not have the expected size.
    """
    output_index_dir = os.path.join(data_home, 'yourmt3_indexes')
    os.makedirs(output_index_dir, exist_ok=True)

    def index_path(name: str) -> str:
        return os.path.join(output_index_dir, f'{dataset_name}_{name}.json')

    def load(name: str, expected_len: int) -> dict:
        path = index_path(name)
        with open(path, 'r') as f:
            file_list = json.load(f)
        assert len(file_list) == expected_len, f'{path}: expected {expected_len} entries, found {len(file_list)}'
        return file_list

    def write(entries: list, name: str) -> None:
        output_file = index_path(name)
        with open(output_file, 'w') as f:
            json.dump(dict(enumerate(entries)), f, indent=4)
        print(f'Created {output_file}')

    def style_of(guitarset_id: str) -> str:
        # '00_BN1-129-Eb_comp_mic' -> 'BN1-129-Eb' -> 'BN1' -> 'BN'
        parts = guitarset_id.split('_')
        return parts[1].split('-')[0].rstrip('0123456789') if len(parts) > 1 else ''

    fl_pshift = load('all_pshift_file_list', 7920)  # pitch-shifted 'all' list, used for training
    fl = load('all_file_list', 720)  # unshifted 'all' list, used for testing

    for style in ['BN', 'Funk', 'SS', 'Jazz', 'Rock']:
        write([v for v in fl_pshift.values() if style_of(v['guitarset_id']) == style], f'{style}_pshift_file_list')
        write([v for v in fl.values() if style_of(v['guitarset_id']) == style], f'{style}_file_list')
# BASIC_PITCH_VALIDATION_IDS = [
# "05_Funk2-108-Eb_comp", "04_BN2-131-B_comp", "04_Jazz3-150-C_solo", "05_Rock2-85-F_solo",
# "05_Funk3-98-A_comp", "05_BN3-119-G_comp", "02_SS2-107-Ab_solo", "01_BN2-131-B_solo",
# "00_BN2-166-Ab_comp", "04_SS1-100-C#_solo", "01_BN2-166-Ab_solo", "01_Rock1-130-A_solo",
# "04_Funk2-119-G_solo", "01_SS2-107-Ab_comp", "05_Funk3-98-A_solo", "05_Funk1-114-Ab_comp",
# "05_Jazz2-187-F#_solo", "05_SS1-100-C#_comp", "00_Rock3-148-C_solo", "02_Rock3-117-Bb_comp",
# "01_BN1-147-Gb_solo", "01_Rock1-90-C#_solo", "01_SS2-107-Ab_solo", "02_Jazz3-150-C_solo",
# "00_Funk1-97-C_solo", "05_SS3-98-C_solo", "03_Rock3-148-C_comp", "03_Rock3-117-Bb_solo",
# "04_Jazz2-187-F#_solo", "05_Jazz2-187-F#_comp", "02_SS1-68-E_solo", "04_SS2-88-F_solo",
# "04_BN2-131-B_solo", "04_Jazz3-137-Eb_comp", "00_SS2-107-Ab_comp", "01_Rock1-130-A_comp",
# "00_Jazz1-130-D_comp", "04_Funk2-108-Eb_comp", "05_BN2-166-Ab_comp"
# ]
# BASIC_PITCH_TEST_IDS = [
# "04_SS3-84-Bb_solo", "02_Funk1-114-Ab_solo", "05_Funk1-114-Ab_solo", "05_Funk1-97-C_solo",
# "00_Rock3-148-C_comp", "00_Jazz3-137-Eb_comp", "00_Jazz1-200-B_comp", "03_SS3-98-C_solo",
# "05_Jazz1-130-D_comp", "00_Jazz2-110-Bb_comp", "02_Funk3-98-A_comp", "04_Rock1-130-A_comp",
# "03_BN1-129-Eb_comp", "03_Funk2-119-G_comp", "05_BN1-147-Gb_comp", "02_Rock1-90-C#_comp",
# "00_Funk3-98-A_solo", "01_SS1-100-C#_comp", "00_Funk3-98-A_comp", "02_BN3-154-E_comp",
# "01_Jazz3-137-Eb_comp", "00_BN2-131-B_comp", "04_SS1-68-E_solo", "05_Funk1-97-C_comp",
# "04_Jazz3-137-Eb_solo", "05_Rock2-142-D_solo", "02_BN3-119-G_solo", "02_Rock2-142-D_solo",
# "01_BN1-129-Eb_solo", "00_Rock2-85-F_comp", "00_Rock1-130-A_solo"
# ]
# def create_filelist_for_basic_pitch_benchmark_guitarset16k(data_home: os.PathLike,
# dataset_name: str = 'guitarset') -> None:
# # Directory and file paths
# base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k')
# output_index_dir = os.path.join(data_home, 'yourmt3_indexes')
# os.makedirs(output_index_dir, exist_ok=True)
# # Load filelist, pshift_all for train
# file_list_pshift_all_path = os.path.join(output_index_dir,
# f'{dataset_name}_all_pshift_file_list.json')
# with open(file_list_pshift_all_path, 'r') as f:
# fl_pshift = json.load(f)
# assert len(fl_pshift) == 7920
# # Load filelist, all without pshift
# file_list_all_path = os.path.join(output_index_dir, f'{dataset_name}_all_file_list.json')
# with open(file_list_all_path, 'r') as f:
# fl = json.load(f)
# assert len(fl) == 720
# # This was abandoned because the split is not an official one.