Spaces:
Sleeping
Sleeping
# Copyright 2024 The YourMT3 Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Please see the details in the LICENSE file. | |
""" note_event_roundtrip_test.py: | |
    This file contains tests for the round-trip conversion between Note,
    NoteEvent and Event.

    Itinerary 1:
        NoteEvent → Event → Token → Event → NoteEvent

    Itinerary 2:
        Note → NoteEvent → Event → Token → Event → NoteEvent → Note

    Training:
        (Dataloader) NoteEvent → (augmentation) → Event → Token

    Evaluation:
        (Model side) Token → Event → NoteEvent → Note → (mir_eval)
        (Ground Truth) Note → (mir_eval)

    • This conversion may fail for unsorted and unquantized timing events.
    • The activity attribute of NoteEvent is often ignorable.
""" | |
import unittest | |
import numpy as np | |
from assert_fns import assert_notes_almost_equal | |
from assert_fns import assert_note_events_almost_equal | |
from assert_fns import assert_track_metrics_score1 | |
from utils.note_event_dataclasses import Note, NoteEvent, Event | |
from utils.note2event import note2note_event, note_event2event | |
from utils.note2event import validate_notes, trim_overlapping_notes | |
from utils.event2note import event2note_event, note_event2note | |
from utils.tokenizer import EventTokenizer, NoteEventTokenizer | |
from utils.midi import note_event2midi | |
from utils.midi import midi2note | |
from utils.note2event import slice_multiple_note_events_and_ties_to_bundle | |
from utils.event2note import merge_zipped_note_events_and_ties_to_notes | |
from utils.metrics import compute_track_metrics | |
from config.vocabulary import GM_INSTR_FULL, SINGING_SOLO_CLASS | |
# yapf: disable | |
class TestNoteEventRoundTrip1(unittest.TestCase):
    """Round-trip checks that start from a hand-written list of NoteEvents."""

    def setUp(self) -> None:
        """Create the NoteEvent fixture and the tokenizer shared by the tests."""
        # One row per event: (is_drum, program, time, velocity, pitch).
        # velocity=1 marks an onset and velocity=0 the matching offset.
        rows = [
            (False, 33, 0, 1, 60),
            (True, 128, 0.2, 1, 36),
            (False, 33, 1.5, 0, 60),
            (False, 33, 1.6, 1, 62),
            (False, 100, 1.6, 1, 77),
            (False, 100, 2.0, 0, 77),
            (True, 128, 2.0, 1, 38),
            (False, 33, 2.0, 0, 62),
        ]
        self.note_events = [
            NoteEvent(is_drum=drum, program=prog, time=sec, velocity=vel, pitch=key, activity=set())
            for drum, prog, sec, vel, key in rows
        ]
        self.tokenizer = EventTokenizer()

    def test_note_event_rt_ne2e2ne(self):
        """NoteEvent → Event → NoteEvent"""
        source = self.note_events.copy()

        encoded = note_event2event(note_events=source, tie_note_events=None, start_time=0, sort=True)
        decoded = event2note_event(encoded, start_time=0, sort=True, tps=100)
        # Only the reconstructed events and the error counter are checked here.
        restored, _ties, _last_activity, errors = decoded

        self.assertSequenceEqual(source, restored)
        self.assertEqual(len(errors), 0)

    def test_note_event_rt_ne2e2t2e2ne(self):
        """NoteEvent → Event → Token → Event → NoteEvent"""
        source = self.note_events.copy()

        encoded = note_event2event(note_events=source, tie_note_events=None, start_time=0, sort=True)
        # Pass the events through the tokenizer and back before decoding.
        token_ids = self.tokenizer.encode(encoded)
        encoded = self.tokenizer.decode(token_ids)

        decoded = event2note_event(encoded, start_time=0, sort=True, tps=100)
        restored, _ties, _last_activity, errors = decoded

        self.assertSequenceEqual(source, restored)
        self.assertEqual(len(errors), 0)
class TestNoteEvent2(unittest.TestCase):
    """Round-trip tests that start from Note objects or from a MIDI file."""

    def setUp(self) -> None:
        """Build a small list of notes and the event tokenizer used by the tests."""
        notes = [
            Note(is_drum=False, program=33, onset=0, offset=1.5, pitch=60, velocity=1),
            Note(is_drum=True, program=128, onset=0.2, offset=0.21, pitch=36, velocity=1),
            Note(is_drum=False, program=25, onset=0.4, offset=1.1, pitch=55, velocity=1),
            Note(is_drum=True, program=128, onset=1, offset=1.01, pitch=42, velocity=1),
            Note(is_drum=False, program=33, onset=1.2, offset=1.8, pitch=80, velocity=1),
            Note(is_drum=False, program=33, onset=1.6, offset=2.0, pitch=62, velocity=1),
            Note(is_drum=False, program=100, onset=1.6, offset=2.0, pitch=77, velocity=1),
            Note(is_drum=False, program=98, onset=1.7, offset=2.0, pitch=77, velocity=1),
            Note(is_drum=True, program=128, onset=1.9, offset=1.91, pitch=38, velocity=1),
        ]
        # Validate and trim notes to make sure they are valid: both helpers
        # are expected to return a sequence equal to the fixture.
        _notes = validate_notes(notes, fix=True)
        self.assertSequenceEqual(notes, _notes)
        _notes = trim_overlapping_notes(notes, sort=True)
        self.assertSequenceEqual(notes, _notes)
        self.notes = notes
        self.tokenizer = EventTokenizer()

    def test_note_event_rt_n2ne2e2t2e2ne2n(self) -> None:
        """Note → NoteEvent → Event → Token → Event → NoteEvent → Note"""
        notes = self.notes.copy()
        note_events = note2note_event(notes=notes, sort=True)
        events = note_event2event(note_events=note_events,
                                  tie_note_events=None,
                                  start_time=0,
                                  tps=100,
                                  sort=True)
        tokens = self.tokenizer.encode(events)
        events = self.tokenizer.decode(tokens)

        recon_note_events, _unused_tie_note_events, _unused_last_activity, err_cnt = event2note_event(
            events, start_time=0, sort=True, tps=100)
        self.assertEqual(len(err_cnt), 0)

        recon_notes, err_cnt = note_event2note(note_events=recon_note_events, sort=True)
        self.assertEqual(len(err_cnt), 0)
        assert_notes_almost_equal(notes, recon_notes, delta=5e-3)  # 5 ms on/offset tolerance

    # NOTE: the test below is disabled. Its call to `event2note_event` unpacks
    # three return values, whereas the active tests in this file unpack four,
    # so it needs updating before it can be re-enabled.
    #
    # def test_encoding_from_midi_without_slicing_zz(self):
    #     """ MIDI → Note → NoteEvent → Event → Token → Event → NoteEvent → Note → MIDI """
    #     src_midi_file = 'extras/examples/1727.mid'
    #     notes, _ = midi2note(src_midi_file, quantize=False)
    #     note_events = note2note_event(notes=notes, sort=True)
    #     events = note_event2event(note_events=note_events,
    #                               tie_note_events=None,
    #                               start_time=0,
    #                               tps=100,
    #                               sort=True)
    #     # check accumulated time by all the shift events
    #     last_shift = 0
    #     for ev in events:
    #         if ev.type == "shift":
    #             last_shift = ev.value
    #     last_shift_in_sec = last_shift / 100  # 447.04
    #     assert last_shift_in_sec == 447.04
    #     # compare with the last offset time
    #     last_offset_time = 0.
    #     for n in notes:
    #         if last_offset_time < n.offset:
    #             last_offset_time = n.offset  # 447.0395833...
    #     self.assertAlmostEqual(last_shift_in_sec, last_offset_time, delta=1e-3)
    #     tokens = self.tokenizer.encode(events)
    #     # reconstruction -----------------------------------------------------------
    #     recon_events = self.tokenizer.decode(tokens)
    #     self.assertSequenceEqual(events, recon_events)
    #     recon_note_events, unused_tie_note_events, err_cnt = event2note_event(recon_events)
    #     self.assertEqual(len(err_cnt), 0)
    #     assert_note_events_almost_equal(note_events, recon_note_events)
    #     recon_notes, err_cnt = note_event2note(note_events=recon_note_events, sort=True, fix_offset=False)
    #     self.assertEqual(len(err_cnt), 0)
    #     assert_notes_almost_equal(notes, recon_notes, delta=5e-3)
    #     # evaluation without MIDI
    #     drum_metric, non_drum_metric, instr_metric = compute_track_metrics(recon_notes, notes, eval_vocab=GM_INSTR_FULL, onset_threshold=0.5)
    #     assert_track_metrics_score1(drum_metric)
    #     assert_track_metrics_score1(non_drum_metric)
    #     assert_track_metrics_score1(instr_metric)
    #     # evaluation through MIDI
    #     note_event2midi(recon_note_events, output_file='extras/examples/recon_1727.mid')
    #     re_recon_notes, _ = midi2note('extras/examples/recon_1727.mid', quantize=False)
    #     drum_metric, non_drum_metric, instr_metric = compute_track_metrics(re_recon_notes, notes, eval_vocab=GM_INSTR_FULL, onset_threshold=0.5)
    #     assert_track_metrics_score1(drum_metric)
    #     assert_track_metrics_score1(non_drum_metric)
    #     assert_track_metrics_score1(instr_metric)

    def test_encoding_from_midi_with_slicing_zz(self) -> None:
        """MIDI → Note → NoteEvent → sliced segments → Token → NoteEvent → Note"""
        # Other candidates (these are from musicnet_em):
        # 'extras/examples/1727.mid', 'extras/examples/1733.mid'
        src_midi_file = 'extras/examples/2106.mid'
        notes, max_time = midi2note(src_midi_file, quantize=False)
        note_events = note2note_event(notes=notes, sort=True)

        # Slice note events into fixed-length segments. The segment count and
        # the segment duration must be derived from the SAME length; the
        # original mixed 32757 and 32767, which was a typo.
        # NOTE(review): presumably 32767 audio samples at 16 kHz — confirm.
        seg_len_units = 32767
        units_per_sec = 16000
        seg_len_sec = seg_len_units / units_per_sec
        num_segs = int(max_time * units_per_sec // seg_len_units + 1)
        start_times = [i * seg_len_sec for i in range(num_segs)]
        note_event_segments = slice_multiple_note_events_and_ties_to_bundle(
            note_events,
            start_times,
            seg_len_sec,
        )

        # Encode each segment into one fixed-width row of the token array.
        tokenizer = NoteEventTokenizer()
        token_array = np.zeros((num_segs, 1024), dtype=np.int32)
        for i, tup in enumerate(zip(*note_event_segments.values())):
            token_array[i, :] = tokenizer.encode_plus(*tup)

        # Decode. (An earlier "Invalid pitch event without program or
        # velocity" warning here has been solved.)
        zipped_note_events_and_tie, _list_events, err_cnt = tokenizer.decode_list_batches(
            [token_array], start_times, return_events=True)
        self.assertEqual(len(err_cnt), 0)

        # First check: per-segment event counts and the number of empty segments.
        cnt_org_empty = 0
        cnt_recon_empty = 0
        for i, (recon_note_events, _ties, _last_activity, _start) in enumerate(zipped_note_events_and_tie):
            org_note_events = note_event_segments['note_events'][i]
            if not org_note_events:
                cnt_org_empty += 1
            if not recon_note_events:
                cnt_recon_empty += 1
            # assertEqual instead of a bare assert: it is not stripped under
            # `python -O` and it reports both lengths on failure.
            self.assertEqual(len(org_note_events), len(recon_note_events))  # passed after bug fix
            # self.assertEqual(len(org_tie_note_events), len(recon_tie_note_events))
        self.assertEqual(cnt_org_empty, cnt_recon_empty)

        # Check the reconstruction of note_events, segment by segment. Both
        # sides are sorted in place with the same keys so ordering differences
        # do not cause spurious mismatches.
        def note_event_key(n_ev):
            return (n_ev.time, n_ev.is_drum, n_ev.program, n_ev.velocity, n_ev.pitch)

        def tie_key(n_ev):
            return (n_ev.program, n_ev.pitch)

        for i, (recon_note_events, recon_tie_note_events, _last_activity, _start) in enumerate(zipped_note_events_and_tie):
            org_note_events = note_event_segments['note_events'][i]
            org_tie_note_events = note_event_segments['tie_note_events'][i]

            org_note_events.sort(key=note_event_key)
            org_tie_note_events.sort(key=tie_key)
            recon_note_events.sort(key=note_event_key)
            recon_tie_note_events.sort(key=tie_key)

            assert_note_events_almost_equal(org_note_events, recon_note_events)
            assert_note_events_almost_equal(org_tie_note_events, recon_tie_note_events, ignore_time=True)

        # Check notes
        recon_notes, err_cnt = merge_zipped_note_events_and_ties_to_notes(zipped_note_events_and_tie, fix_offset=False)
        self.assertEqual(len(err_cnt), 0)
        assert_notes_almost_equal(notes, recon_notes, delta=5.1e-3)

        # Check metric
        drum_metric, non_drum_metric, instr_metric = compute_track_metrics(
            recon_notes, notes, eval_vocab=GM_INSTR_FULL, onset_tolerance=0.005)  # 5ms
        self.assertEqual(non_drum_metric['onset_f'], 1.0)
# yapf: enable | |
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()