Spaces:
Build error
Build error
# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
"""metrics_test.py:
This file contains tests for the following classes:
    • AMTMetrics
"""
import unittest | |
import warnings | |
import torch | |
import numpy as np | |
from utils.metrics import AMTMetrics | |
from utils.metrics import compute_track_metrics | |
class TestAMTMetrics(unittest.TestCase):
    """Unit tests for ``AMTMetrics`` and ``compute_track_metrics``."""

    def test_individual_attributes(self):
        """Update, compute and reset a single metric attribute (``onset_f``)."""
        metric = AMTMetrics()

        # Test updating the metric using .update() method
        metric.onset_f.update(0.5)

        # Test updating the metric using __call__() method
        metric.onset_f(0.5)

        # Test updating the metric with a weight
        metric.onset_f(0, weight=1.0)

        # Test computing the average value of the metric: (0.5 + 0.5 + 0) / 3
        computed_value = metric.onset_f.compute()
        self.assertAlmostEqual(computed_value, 0.3333333333333333)

        # Test resetting the metric: computing right after a reset is expected
        # to emit a UserWarning and to produce NaN.
        metric.onset_f.reset()
        with self.assertWarns(UserWarning):
            reset_value = metric.onset_f.compute()
        # NOTE: ``torch._assert(value, torch.nan)`` only checks that ``value``
        # is truthy (NaN is truthy), so it can never fail here. Check for NaN
        # explicitly instead; ``as_tensor`` accepts both floats and tensors.
        self.assertTrue(bool(torch.isnan(torch.as_tensor(reset_value)).all()))

        # Test bulk_compute: it is also expected to warn when nothing has
        # been updated since the reset.
        with self.assertWarns(UserWarning):
            metric.bulk_compute()

    def test_bulk_update_and_compute(self):
        """Feed metrics through ``bulk_update`` in both accepted dict forms."""
        metric = AMTMetrics()

        # Test bulk_update with values only
        d1 = {'onset_f': 0.5, 'offset_f': 0.5}
        metric.bulk_update(d1)

        # Test bulk_update with values and weights
        d2 = {
            'onset_f': {'value': 0.5, 'weight': 1.0},
            'offset_f': {'value': 0.5, 'weight': 1.0},
        }
        metric.bulk_update(d2)

        # Test bulk_compute
        computed_metrics = metric.bulk_compute()

        # Ensure the 'onset_f' and 'offset_f' keys exist in the computed_metrics dictionary
        self.assertIn('onset_f', computed_metrics)
        self.assertIn('offset_f', computed_metrics)

        # Check the computed values
        self.assertAlmostEqual(computed_metrics['onset_f'], 0.5)
        self.assertAlmostEqual(computed_metrics['offset_f'], 0.5)

    def _assert_perfect_singing_scores(self, est_notes, ref_notes, eval_vocab):
        """Evaluate ``est_notes`` against ``ref_notes`` and check singing scores.

        Runs ``compute_track_metrics`` with ``eval_vocab`` (no drum vocabulary,
        onset tolerance 0.05), accumulates the three returned metric dicts in a
        fresh ``AMTMetrics`` and asserts that exactly six computed entries have
        'Singing Voice' in their key, each equal to 1.0.
        """
        metric = AMTMetrics(prefix='test/', extra_classes=list(eval_vocab.keys()))
        drum_metric, non_drum_metric, instr_metric = compute_track_metrics(
            est_notes,
            ref_notes,
            eval_vocab=eval_vocab,
            eval_drum_vocab=None,
            onset_tolerance=0.05,
        )
        for partial_metric in (drum_metric, non_drum_metric, instr_metric):
            metric.bulk_update(partial_metric)
        computed_metrics = metric.bulk_compute()

        singing_values = [v for k, v in computed_metrics.items() if 'Singing Voice' in k]
        for value in singing_values:
            self.assertEqual(value, 1.0)
        self.assertEqual(len(singing_values), 6)

    def test_compute_track_metrics_singing(self):
        """Notes rebuilt from the reference note events must score perfectly.

        The estimated notes are derived from the reference note events of the
        same example, so every 'Singing Voice' metric should be 1.0, both with
        the singing-only vocabulary and with the larger GM-plus vocabulary.
        """
        from config.vocabulary import SINGING_SOLO_CLASS, GM_INSTR_CLASS_PLUS
        from utils.event2note import note_event2note

        # NOTE(review): allow_pickle=True is only safe because these are
        # trusted fixture files shipped with the repository.
        ref_notes_dict = np.load('extras/examples/singing_notes.npy', allow_pickle=True).tolist()
        ref_note_events_dict = np.load('extras/examples/singing_note_events.npy', allow_pickle=True).tolist()
        est_notes, _ = note_event2note(ref_note_events_dict['note_events'])
        ref_notes = ref_notes_dict['notes']

        for eval_vocab in (SINGING_SOLO_CLASS, GM_INSTR_CLASS_PLUS):
            self._assert_perfect_singing_scores(est_notes, ref_notes, eval_vocab)
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()