# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transcription metrics."""
import collections
import copy
import functools
from typing import Any, Iterable, Mapping, Optional, Sequence
import mir_eval
from mt3 import event_codec
from mt3 import metrics_utils
from mt3 import note_sequences
from mt3 import spectrograms
from mt3 import summaries
from mt3 import vocabularies
import note_seq
import numpy as np
import seqio


def _program_aware_note_scores(
    ref_ns: note_seq.NoteSequence,
    est_ns: note_seq.NoteSequence,
    granularity_type: str
) -> Mapping[str, float]:
  """Compute precision/recall/F1 for notes, taking program into account.

  For non-drum tracks, uses onsets and offsets. For drum tracks, uses onsets
  only. Applies the MIDI program map of the specified granularity type.

  Args:
    ref_ns: Reference NoteSequence with ground truth labels.
    est_ns: Estimated NoteSequence.
    granularity_type: String key in the vocabularies.PROGRAM_GRANULARITIES
      dict.

  Returns:
    A dictionary of precision, recall, and F1 scores: overall, drums only,
    and non-drums only, keyed by granularity type.
  """
  program_map_fn = vocabularies.PROGRAM_GRANULARITIES[
      granularity_type].program_map_fn

  ref_ns = copy.deepcopy(ref_ns)
  for note in ref_ns.notes:
    if not note.is_drum:
      note.program = program_map_fn(note.program)

  est_ns = copy.deepcopy(est_ns)
  for note in est_ns.notes:
    if not note.is_drum:
      note.program = program_map_fn(note.program)

  program_and_is_drum_tuples = (
      set((note.program, note.is_drum) for note in ref_ns.notes) |
      set((note.program, note.is_drum) for note in est_ns.notes)
  )
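
  # Score each (program, is_drum) track separately, weighting precision by
  # the number of estimated notes and recall by the number of reference notes,
  # so the final averages are note-weighted rather than track-weighted.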
  drum_precision_sum = 0.0
  drum_precision_count = 0
  drum_recall_sum = 0.0
  drum_recall_count = 0

  nondrum_precision_sum = 0.0
  nondrum_precision_count = 0
  nondrum_recall_sum = 0.0
  nondrum_recall_count = 0

  for program, is_drum in program_and_is_drum_tuples:
    est_track = note_sequences.extract_track(est_ns, program, is_drum)
    ref_track = note_sequences.extract_track(ref_ns, program, is_drum)

    est_intervals, est_pitches, unused_est_velocities = (
        note_seq.sequences_lib.sequence_to_valued_intervals(est_track))
    ref_intervals, ref_pitches, unused_ref_velocities = (
        note_seq.sequences_lib.sequence_to_valued_intervals(ref_track))

    args = {
        'ref_intervals': ref_intervals, 'ref_pitches': ref_pitches,
        'est_intervals': est_intervals, 'est_pitches': est_pitches
    }
    if is_drum:
      args['offset_ratio'] = None
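    # With offset_ratio=None, mir_eval ignores offsets entirely, so drum
    # tracks are scored on onsets (and pitches) only.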
    precision, recall, unused_f_measure, unused_avg_overlap_ratio = (
        mir_eval.transcription.precision_recall_f1_overlap(**args))

    if is_drum:
      drum_precision_sum += precision * len(est_intervals)
      drum_precision_count += len(est_intervals)
      drum_recall_sum += recall * len(ref_intervals)
      drum_recall_count += len(ref_intervals)
    else:
      nondrum_precision_sum += precision * len(est_intervals)
      nondrum_precision_count += len(est_intervals)
      nondrum_recall_sum += recall * len(ref_intervals)
      nondrum_recall_count += len(ref_intervals)

  precision_sum = drum_precision_sum + nondrum_precision_sum
  precision_count = drum_precision_count + nondrum_precision_count
  recall_sum = drum_recall_sum + nondrum_recall_sum
  recall_count = drum_recall_count + nondrum_recall_count

  precision = (precision_sum / precision_count) if precision_count else 0
  recall = (recall_sum / recall_count) if recall_count else 0
  f_measure = mir_eval.util.f_measure(precision, recall)

  drum_precision = ((drum_precision_sum / drum_precision_count)
                    if drum_precision_count else 0)
  drum_recall = ((drum_recall_sum / drum_recall_count)
                 if drum_recall_count else 0)
  drum_f_measure = mir_eval.util.f_measure(drum_precision, drum_recall)

  nondrum_precision = ((nondrum_precision_sum / nondrum_precision_count)
                       if nondrum_precision_count else 0)
  nondrum_recall = ((nondrum_recall_sum / nondrum_recall_count)
                    if nondrum_recall_count else 0)
  nondrum_f_measure = mir_eval.util.f_measure(nondrum_precision,
                                              nondrum_recall)

  return {
      f'Onset + offset + program precision ({granularity_type})': precision,
      f'Onset + offset + program recall ({granularity_type})': recall,
      f'Onset + offset + program F1 ({granularity_type})': f_measure,
      f'Drum onset precision ({granularity_type})': drum_precision,
      f'Drum onset recall ({granularity_type})': drum_recall,
      f'Drum onset F1 ({granularity_type})': drum_f_measure,
      f'Nondrum onset + offset + program precision ({granularity_type})':
          nondrum_precision,
      f'Nondrum onset + offset + program recall ({granularity_type})':
          nondrum_recall,
      f'Nondrum onset + offset + program F1 ({granularity_type})':
          nondrum_f_measure
  }
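
# A minimal, hypothetical usage sketch (the notes below are invented; the
# NoteSequence construction uses the standard note_seq protobuf API):
#
#   ref = note_seq.NoteSequence()
#   ref.notes.add(pitch=60, start_time=0.0, end_time=0.5, velocity=100,
#                 program=0)  # acoustic grand piano
#   est = note_seq.NoteSequence()
#   est.notes.add(pitch=60, start_time=0.01, end_time=0.5, velocity=100,
#                 program=2)  # electric grand piano
#   _program_aware_note_scores(ref, est, granularity_type='midi_class')
#
# Under 'midi_class' granularity both programs map to the piano class, so
# the note can match; under 'full' granularity the programs differ, the
# notes land in separate tracks, and no match is counted.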


def _note_onset_tolerance_sweep(
    ref_ns: note_seq.NoteSequence, est_ns: note_seq.NoteSequence,
    tolerances: Iterable[float] = (0.01, 0.02, 0.05, 0.1, 0.2, 0.5)
) -> Mapping[str, float]:
  """Compute note precision/recall/F1 across a range of tolerances."""
  est_intervals, est_pitches, unused_est_velocities = (
      note_seq.sequences_lib.sequence_to_valued_intervals(est_ns))
  ref_intervals, ref_pitches, unused_ref_velocities = (
      note_seq.sequences_lib.sequence_to_valued_intervals(ref_ns))

  scores = {}

  for tol in tolerances:
    precision, recall, f_measure, _ = (
        mir_eval.transcription.precision_recall_f1_overlap(
            ref_intervals=ref_intervals, ref_pitches=ref_pitches,
            est_intervals=est_intervals, est_pitches=est_pitches,
            onset_tolerance=tol, offset_min_tolerance=tol))
    scores[f'Onset + offset precision ({tol})'] = precision
    scores[f'Onset + offset recall ({tol})'] = recall
    scores[f'Onset + offset F1 ({tol})'] = f_measure

  return scores
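
# A minimal, hypothetical sketch of the sweep (note values are invented):
#
#   ref = note_seq.NoteSequence()
#   ref.notes.add(pitch=64, start_time=0.0, end_time=1.0, velocity=80)
#   est = note_seq.NoteSequence()
#   est.notes.add(pitch=64, start_time=0.03, end_time=1.0, velocity=80)
#   _note_onset_tolerance_sweep(ref_ns=ref, est_ns=est)
#
# The 30 ms onset error exceeds the 0.01 and 0.02 tolerances but not 0.05
# and above, so F1 jumps from 0.0 to 1.0 partway through the sweep.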


def transcription_metrics(
    targets: Sequence[Mapping[str, Any]],
    predictions: Sequence[Mapping[str, Any]],
    codec: event_codec.Codec,
    spectrogram_config: spectrograms.SpectrogramConfig,
    onsets_only: bool,
    use_ties: bool,
    track_specs: Optional[Sequence[note_sequences.TrackSpec]] = None,
    num_summary_examples: int = 5,
    frame_fps: float = 62.5,
    frame_velocity_threshold: int = 30,
) -> Mapping[str, seqio.metrics.MetricValue]:
  """Compute mir_eval transcription metrics."""
  if onsets_only and use_ties:
    raise ValueError('Ties not compatible with onset-only transcription.')

  if onsets_only:
    encoding_spec = note_sequences.NoteOnsetEncodingSpec
  elif not use_ties:
    encoding_spec = note_sequences.NoteEncodingSpec
  else:
    encoding_spec = note_sequences.NoteEncodingWithTiesSpec
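  # The encoding spec controls how decoded event tokens are reassembled into
  # notes: onsets only, onsets plus offsets, or onsets plus offsets with tie
  # events marking notes that continue across segment boundaries.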

  # The first target for each full example contains the NoteSequence; just
  # organize by ID.
  full_targets = {}
  for target in targets:
    if target['ref_ns']:
      full_targets[target['unique_id']] = {'ref_ns': target['ref_ns']}

  # Gather all predictions for the same ID and concatenate them in time order,
  # to construct full-length predictions.
  full_predictions = metrics_utils.combine_predictions_by_id(
      predictions=predictions,
      combine_predictions_fn=functools.partial(
          metrics_utils.event_predictions_to_ns,
          codec=codec,
          encoding_spec=encoding_spec))

  assert sorted(full_targets.keys()) == sorted(full_predictions.keys())

  full_target_prediction_pairs = [
      (full_targets[id], full_predictions[id])
      for id in sorted(full_targets.keys())
  ]

  scores = collections.defaultdict(list)
  all_track_pianorolls = collections.defaultdict(list)

  for target, prediction in full_target_prediction_pairs:
    scores['Invalid events'].append(prediction['est_invalid_events'])
    scores['Dropped events'].append(prediction['est_dropped_events'])
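
    # Local helper: copy the sequence, preserving tempos, time signatures,
    # and other metadata via CopyFrom, then rebuild the note list sans drums.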
    def remove_drums(ns):
      ns_drumless = note_seq.NoteSequence()
      ns_drumless.CopyFrom(ns)
      del ns_drumless.notes[:]
      ns_drumless.notes.extend([note for note in ns.notes if not note.is_drum])
      return ns_drumless

    est_ns_drumless = remove_drums(prediction['est_ns'])
    ref_ns_drumless = remove_drums(target['ref_ns'])

    # Whether or not there are separate tracks, compute metrics for the full
    # NoteSequence minus drums.
    est_tracks = [est_ns_drumless]
    ref_tracks = [ref_ns_drumless]
    use_track_offsets = [not onsets_only]
    use_track_velocities = [not onsets_only]
    track_instrument_names = ['']
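    # The five lists above are parallel: index 0 is the full drumless mix;
    # per-track entries are appended below when track_specs is provided.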
    if track_specs is not None:
      # Compute transcription metrics separately for each track.
      for spec in track_specs:
        est_tracks.append(note_sequences.extract_track(
            prediction['est_ns'], spec.program, spec.is_drum))
        ref_tracks.append(note_sequences.extract_track(
            target['ref_ns'], spec.program, spec.is_drum))
        use_track_offsets.append(not onsets_only and not spec.is_drum)
        use_track_velocities.append(not onsets_only)
        track_instrument_names.append(spec.name)

    for est_ns, ref_ns, use_offsets, use_velocities, instrument_name in zip(
        est_tracks, ref_tracks, use_track_offsets, use_track_velocities,
        track_instrument_names):
      track_scores = {}

      est_intervals, est_pitches, est_velocities = (
          note_seq.sequences_lib.sequence_to_valued_intervals(est_ns))
      ref_intervals, ref_pitches, ref_velocities = (
          note_seq.sequences_lib.sequence_to_valued_intervals(ref_ns))

      # Precision / recall / F1 using onsets (and pitches) only.
      precision, recall, f_measure, avg_overlap_ratio = (
          mir_eval.transcription.precision_recall_f1_overlap(
              ref_intervals=ref_intervals,
              ref_pitches=ref_pitches,
              est_intervals=est_intervals,
              est_pitches=est_pitches,
              offset_ratio=None))
      del avg_overlap_ratio
      track_scores['Onset precision'] = precision
      track_scores['Onset recall'] = recall
      track_scores['Onset F1'] = f_measure

      if use_offsets:
        # Precision / recall / F1 using onsets and offsets.
        precision, recall, f_measure, avg_overlap_ratio = (
            mir_eval.transcription.precision_recall_f1_overlap(
                ref_intervals=ref_intervals,
                ref_pitches=ref_pitches,
                est_intervals=est_intervals,
                est_pitches=est_pitches))
        del avg_overlap_ratio
        track_scores['Onset + offset precision'] = precision
        track_scores['Onset + offset recall'] = recall
        track_scores['Onset + offset F1'] = f_measure

      if use_velocities:
        # Precision / recall / F1 using onsets and velocities (no offsets).
        precision, recall, f_measure, avg_overlap_ratio = (
            mir_eval.transcription_velocity.precision_recall_f1_overlap(
                ref_intervals=ref_intervals,
                ref_pitches=ref_pitches,
                ref_velocities=ref_velocities,
                est_intervals=est_intervals,
                est_pitches=est_pitches,
                est_velocities=est_velocities,
                offset_ratio=None))
        track_scores['Onset + velocity precision'] = precision
        track_scores['Onset + velocity recall'] = recall
        track_scores['Onset + velocity F1'] = f_measure

      if use_offsets and use_velocities:
        # Precision / recall / F1 using onsets, offsets, and velocities.
        precision, recall, f_measure, avg_overlap_ratio = (
            mir_eval.transcription_velocity.precision_recall_f1_overlap(
                ref_intervals=ref_intervals,
                ref_pitches=ref_pitches,
                ref_velocities=ref_velocities,
                est_intervals=est_intervals,
                est_pitches=est_pitches,
                est_velocities=est_velocities))
        track_scores['Onset + offset + velocity precision'] = precision
        track_scores['Onset + offset + velocity recall'] = recall
        track_scores['Onset + offset + velocity F1'] = f_measure

      # Calculate framewise metrics.
      is_drum = all([n.is_drum for n in ref_ns.notes])
      ref_pr = metrics_utils.get_prettymidi_pianoroll(
          ref_ns, frame_fps, is_drum=is_drum)
      est_pr = metrics_utils.get_prettymidi_pianoroll(
          est_ns, frame_fps, is_drum=is_drum)
      all_track_pianorolls[instrument_name].append((est_pr, ref_pr))
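      # Compare pianorolls framewise; velocities are thresholded by
      # frame_velocity_threshold (see metrics_utils.frame_metrics).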
      frame_precision, frame_recall, frame_f1 = metrics_utils.frame_metrics(
          ref_pr, est_pr, velocity_threshold=frame_velocity_threshold)
      track_scores['Frame Precision'] = frame_precision
      track_scores['Frame Recall'] = frame_recall
      track_scores['Frame F1'] = frame_f1

      for metric_name, metric_value in track_scores.items():
        if instrument_name:
          scores[f'{instrument_name}/{metric_name}'].append(metric_value)
        else:
          scores[metric_name].append(metric_value)

    # Add program-aware note metrics for all program granularities.
    # Note that this interacts with the training program granularity; in
    # particular granularities *higher* than the training granularity are
    # likely to have poor metrics.
    for granularity_type in vocabularies.PROGRAM_GRANULARITIES:
      for name, score in _program_aware_note_scores(
          target['ref_ns'], prediction['est_ns'],
          granularity_type=granularity_type).items():
        scores[name].append(score)

    # Add (non-program-aware) note metrics across a range of onset/offset
    # tolerances.
    for name, score in _note_onset_tolerance_sweep(
        ref_ns=ref_ns_drumless, est_ns=est_ns_drumless).items():
      scores[name].append(score)
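
  # Alongside per-metric means, emit per-example histograms so the spread of
  # scores across evaluation examples is visible.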
  mean_scores = {k: np.mean(v) for k, v in scores.items()}
  score_histograms = {'%s (hist)' % k: seqio.metrics.Histogram(np.array(v))
                      for k, v in scores.items()}

  # Pick several examples to summarize.
  targets_to_summarize, predictions_to_summarize = zip(
      *full_target_prediction_pairs[:num_summary_examples])

  # Compute audio summaries.
  audio_summaries = summaries.audio_summaries(
      targets=targets_to_summarize,
      predictions=predictions_to_summarize,
      spectrogram_config=spectrogram_config)

  # Compute transcription summaries.
  transcription_summaries = summaries.transcription_summaries(
      targets=targets_to_summarize,
      predictions=predictions_to_summarize,
      spectrogram_config=spectrogram_config,
      ns_feature_suffix='ns',
      track_specs=track_specs)

  pianorolls_to_summarize = {
      k: v[:num_summary_examples] for k, v in all_track_pianorolls.items()
  }
  prettymidi_pianoroll_summaries = summaries.prettymidi_pianoroll(
      pianorolls_to_summarize, fps=frame_fps)

  return {
      **mean_scores,
      **score_histograms,
      **audio_summaries,
      **transcription_summaries,
      **prettymidi_pianoroll_summaries,
  }