Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	| """preprocess_egmd.py""" | |
| import os | |
| import csv | |
| import glob | |
| import re | |
| import json | |
| from typing import Dict, List, Tuple | |
| import numpy as np | |
| from utils.audio import get_audio_file_info | |
| from utils.midi import midi2note, note_event2midi | |
| from utils.note2event import note2note_event, note_event2event | |
| from utils.event2note import event2note_event | |
| from utils.note_event_dataclasses import Note, NoteEvent | |
| from utils.utils import note_event2token2note_event_sanity_check | |
| # from utils.utils import assert_note_events_almost_equal | |
def create_note_event_and_note_from_midi(mid_file: str, id: str) -> Tuple[Dict, Dict]:
    """Extract notes and note events, each with metadata, from a MIDI file.

    The MIDI file is parsed once with `midi2note`; the resulting notes are then
    converted with `note2note_event`. Both results are tagged as a single drum
    track (program 128, is_drum 1).

    Args:
        mid_file: Path of the MIDI file, passed to `midi2note`.
        id: Identifier stored alongside both results. The name shadows the
            builtin `id`; it is kept so existing keyword callers still work.

    Returns:
        notes (dict): notes and metadata, with keys 'egmd_id', 'program',
            'is_drum', 'duration_sec' and 'notes'.
        note_events (dict): note events and metadata, with keys 'egmd_id',
            'maps_id', 'program', 'is_drum', 'duration_sec' and 'note_events'.
    """
    notes, dur_sec = midi2note(
        mid_file,
        binary_velocity=True,
        ch_9_as_drum=True,
        force_all_drum=True,
        trim_overlap=True,
        fix_offset=True,
        quantize=True,
        verbose=0,
        minimum_offset_sec=0.01,
        drum_offset_sec=0.01,
        ignore_pedal=True)

    notes_dict = {
        'egmd_id': id,
        'program': [128],
        'is_drum': [1],
        'duration_sec': dur_sec,
        'notes': notes,
    }
    note_events_dict = {
        # Same id key as the notes dict and the index files.
        'egmd_id': id,
        # Legacy key: earlier versions stored the id only under 'maps_id'.
        # Kept so already-written consumers of this dict do not break.
        'maps_id': id,
        'program': [128],
        'is_drum': [1],
        'duration_sec': dur_sec,
        'note_events': note2note_event(notes),
    }
    return notes_dict, note_events_dict
def _egmd_index_entry(base_dir: str, row: Dict[str, str]) -> Dict:
    """Build one file-list entry for a csv row, asserting that its files exist."""
    midi_file = os.path.join(base_dir, row['midi_filename'])
    mix_audio_file = os.path.join(base_dir, row['audio_filename'])
    notes_file = midi_file.replace('.midi', '_notes.npy')
    note_events_file = midi_file.replace('.midi', '_note_events.npy')

    # Verify the files before probing the audio, so that a missing file is
    # reported with its path instead of failing inside get_audio_file_info.
    for path in (mix_audio_file, midi_file, notes_file, note_events_file):
        assert os.path.exists(path), f'File not found: {path}'

    return {
        'egmd_id': row['midi_filename'].split('.')[0],
        'n_frames': get_audio_file_info(mix_audio_file)[1],
        'mix_audio_file': mix_audio_file,
        'notes_file': notes_file,
        'note_events_file': note_events_file,
        'midi_file': midi_file,
        'program': [128],
        'is_drum': [1],
    }


def _write_egmd_file_list(output_index_dir: str, dataset_name: str, split: str,
                          file_list: Dict[int, Dict]) -> None:
    """Write `file_list` to '{dataset_name}_{split}_file_list.json' and report it."""
    output_file = os.path.join(output_index_dir, f'{dataset_name}_{split}_file_list.json')
    with open(output_file, 'w') as f:
        # Integer keys are serialized as strings by json.
        json.dump(file_list, f, indent=4)
    print(f'Wrote {output_file}')


def preprocess_egmd16k(data_home: os.PathLike, dataset_name='egmd') -> None:
    """Preprocess E-GMD: write note files, quantized MIDI and split index files.

    Reads '{data_home}/{dataset_name}_yourmt3_16k/e-gmd-v1.0.0.csv'. For every
    row it saves '*_notes.npy', '*_note_events.npy' and '*_quantized_120bpm.mid'
    next to the source '.midi' file. It then writes one index file per split
    into '{data_home}/yourmt3_indexes'. Audio files are not modified here.

    Args:
        data_home: Root directory that contains the dataset directory.
        dataset_name: Prefix of the dataset directory and of the index files.

    Raises:
        AssertionError: If the csv row count or a split size differs from the
            expected value, or if a file listed in the csv is missing.

    Splits:
        - train: 35217 files
        - validation: 5031 files
        - test: 5289 files
        - test_reduced: 246 files that contain '_5.midi' or '_10.midi' in the filename

    Writes:
        - {dataset_name}_{split}_file_list.json: a dictionary with the following keys:
        {
            index:
            {
                'egmd_id': egmd_id, # filename without extension
                'n_frames': (int),
                'mix_audio_file': 'path/to/mix.wav',
                'notes_file': 'path/to/notes.npy',
                'note_events_file': 'path/to/note_events.npy',
                'midi_file': 'path/to/midi.mid',
                'program': List[int],
                'is_drum': List[int], # 0 or 1
            }
        }
    """
    # Directory and file paths
    base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k')
    output_index_dir = os.path.join(data_home, 'yourmt3_indexes')
    os.makedirs(output_index_dir, exist_ok=True)

    # Load the csv metadata; newline='' lets the csv module handle line endings.
    csv_file = os.path.join(base_dir, 'e-gmd-v1.0.0.csv')
    with open(csv_file, 'r', newline='') as f:
        egmd_dict_list_all = list(csv.DictReader(f))
    assert len(egmd_dict_list_all) == 45537, (
        f'Expected 45537 rows in {csv_file}, found {len(egmd_dict_list_all)}')

    # Process MIDI files
    for d in egmd_dict_list_all:
        egmd_id = d['midi_filename'].split('.')[0]
        midi_file = os.path.join(base_dir, d['midi_filename'])
        notes, note_events = create_note_event_and_note_from_midi(midi_file, egmd_id)

        # Write notes and note_events (dicts, hence allow_pickle=True)
        notes_file = midi_file.replace('.midi', '_notes.npy')
        note_events_file = midi_file.replace('.midi', '_note_events.npy')
        np.save(notes_file, notes, allow_pickle=True, fix_imports=False)
        print(f"Created {notes_file}")
        np.save(note_events_file, note_events, allow_pickle=True, fix_imports=False)
        print(f"Created {note_events_file}")

        # Rewrite 120 bpm quantized midi file
        quantized_midi_file = midi_file.replace('.midi', '_quantized_120bpm.mid')
        note_event2midi(note_events['note_events'], quantized_midi_file)
        print(f'Wrote {quantized_midi_file}')

    # Create index files, one per split, and verify the split sizes.
    expected_counts = {'train': 35217, 'validation': 5031, 'test': 5289}
    for split, expected in expected_counts.items():
        rows = [d for d in egmd_dict_list_all if d['split'] == split]
        file_list = {i: _egmd_index_entry(base_dir, d) for i, d in enumerate(rows)}
        _write_egmd_file_list(output_index_dir, dataset_name, split, file_list)
        assert len(file_list) == expected, (
            f'Expected {expected} files in split {split!r}, found {len(file_list)}')

    # Create reduced test index file
    reduced_rows = []
    for d in egmd_dict_list_all:
        if d['split'] != 'test':
            continue
        midi_file = os.path.join(base_dir, d['midi_filename'])
        if '_5.midi' in midi_file or '_10.midi' in midi_file:
            reduced_rows.append(d)
    file_list = {i: _egmd_index_entry(base_dir, d) for i, d in enumerate(reduced_rows)}
    _write_egmd_file_list(output_index_dir, dataset_name, 'test_reduced', file_list)
    assert len(file_list) == 246, (
        f"Expected 246 files in split 'test_reduced', found {len(file_list)}")
