# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transcription metrics."""
import collections
import copy
import functools
from typing import Any, Iterable, Mapping, Optional, Sequence
import mir_eval
from mt3 import event_codec
from mt3 import metrics_utils
from mt3 import note_sequences
from mt3 import spectrograms
from mt3 import summaries
from mt3 import vocabularies
import note_seq
import numpy as np
import seqio


def _program_aware_note_scores(
ref_ns: note_seq.NoteSequence,
est_ns: note_seq.NoteSequence,
granularity_type: str
) -> Mapping[str, float]:
"""Compute precision/recall/F1 for notes taking program into account.
For non-drum tracks, uses onsets and offsets. For drum tracks, uses onsets
only. Applies MIDI program map of specified granularity type.
Args:
ref_ns: Reference NoteSequence with ground truth labels.
est_ns: Estimated NoteSequence.
granularity_type: String key in vocabularies.PROGRAM_GRANULARITIES dict.
Returns:
A dictionary containing precision, recall, and F1 score.
"""
program_map_fn = vocabularies.PROGRAM_GRANULARITIES[
granularity_type].program_map_fn
ref_ns = copy.deepcopy(ref_ns)
for note in ref_ns.notes:
if not note.is_drum:
note.program = program_map_fn(note.program)
est_ns = copy.deepcopy(est_ns)
for note in est_ns.notes:
if not note.is_drum:
note.program = program_map_fn(note.program)
program_and_is_drum_tuples = (
set((note.program, note.is_drum) for note in ref_ns.notes) |
set((note.program, note.is_drum) for note in est_ns.notes)
)
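  # Accumulate per-track precision weighted by the number of estimated notes
  # and per-track recall weighted by the number of reference notes, keeping
  # drum and non-drum totals separate.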
drum_precision_sum = 0.0
drum_precision_count = 0
drum_recall_sum = 0.0
drum_recall_count = 0
nondrum_precision_sum = 0.0
nondrum_precision_count = 0
nondrum_recall_sum = 0.0
nondrum_recall_count = 0
for program, is_drum in program_and_is_drum_tuples:
est_track = note_sequences.extract_track(est_ns, program, is_drum)
ref_track = note_sequences.extract_track(ref_ns, program, is_drum)
est_intervals, est_pitches, unused_est_velocities = (
note_seq.sequences_lib.sequence_to_valued_intervals(est_track))
ref_intervals, ref_pitches, unused_ref_velocities = (
note_seq.sequences_lib.sequence_to_valued_intervals(ref_track))
args = {
'ref_intervals': ref_intervals, 'ref_pitches': ref_pitches,
'est_intervals': est_intervals, 'est_pitches': est_pitches
}
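    # Drum notes have no meaningful offsets, so score drum tracks on onsets
    # only.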
if is_drum:
args['offset_ratio'] = None
precision, recall, unused_f_measure, unused_avg_overlap_ratio = (
mir_eval.transcription.precision_recall_f1_overlap(**args))
if is_drum:
drum_precision_sum += precision * len(est_intervals)
drum_precision_count += len(est_intervals)
drum_recall_sum += recall * len(ref_intervals)
drum_recall_count += len(ref_intervals)
else:
nondrum_precision_sum += precision * len(est_intervals)
nondrum_precision_count += len(est_intervals)
nondrum_recall_sum += recall * len(ref_intervals)
nondrum_recall_count += len(ref_intervals)
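  # Pool drum and non-drum counts for the combined scores.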
precision_sum = drum_precision_sum + nondrum_precision_sum
precision_count = drum_precision_count + nondrum_precision_count
recall_sum = drum_recall_sum + nondrum_recall_sum
recall_count = drum_recall_count + nondrum_recall_count
precision = (precision_sum / precision_count) if precision_count else 0
recall = (recall_sum / recall_count) if recall_count else 0
f_measure = mir_eval.util.f_measure(precision, recall)
drum_precision = ((drum_precision_sum / drum_precision_count)
if drum_precision_count else 0)
drum_recall = ((drum_recall_sum / drum_recall_count)
if drum_recall_count else 0)
drum_f_measure = mir_eval.util.f_measure(drum_precision, drum_recall)
nondrum_precision = ((nondrum_precision_sum / nondrum_precision_count)
if nondrum_precision_count else 0)
nondrum_recall = ((nondrum_recall_sum / nondrum_recall_count)
if nondrum_recall_count else 0)
nondrum_f_measure = mir_eval.util.f_measure(nondrum_precision, nondrum_recall)
return {
f'Onset + offset + program precision ({granularity_type})': precision,
f'Onset + offset + program recall ({granularity_type})': recall,
f'Onset + offset + program F1 ({granularity_type})': f_measure,
f'Drum onset precision ({granularity_type})': drum_precision,
f'Drum onset recall ({granularity_type})': drum_recall,
f'Drum onset F1 ({granularity_type})': drum_f_measure,
f'Nondrum onset + offset + program precision ({granularity_type})':
nondrum_precision,
f'Nondrum onset + offset + program recall ({granularity_type})':
nondrum_recall,
f'Nondrum onset + offset + program F1 ({granularity_type})':
nondrum_f_measure
  }


def _note_onset_tolerance_sweep(
ref_ns: note_seq.NoteSequence, est_ns: note_seq.NoteSequence,
tolerances: Iterable[float] = (0.01, 0.02, 0.05, 0.1, 0.2, 0.5)
) -> Mapping[str, float]:
"""Compute note precision/recall/F1 across a range of tolerances."""
est_intervals, est_pitches, unused_est_velocities = (
note_seq.sequences_lib.sequence_to_valued_intervals(est_ns))
ref_intervals, ref_pitches, unused_ref_velocities = (
note_seq.sequences_lib.sequence_to_valued_intervals(ref_ns))
scores = {}
for tol in tolerances:
precision, recall, f_measure, _ = (
mir_eval.transcription.precision_recall_f1_overlap(
ref_intervals=ref_intervals, ref_pitches=ref_pitches,
est_intervals=est_intervals, est_pitches=est_pitches,
onset_tolerance=tol, offset_min_tolerance=tol))
scores[f'Onset + offset precision ({tol})'] = precision
scores[f'Onset + offset recall ({tol})'] = recall
scores[f'Onset + offset F1 ({tol})'] = f_measure
  return scores


def transcription_metrics(
targets: Sequence[Mapping[str, Any]],
predictions: Sequence[Mapping[str, Any]],
codec: event_codec.Codec,
spectrogram_config: spectrograms.SpectrogramConfig,
onsets_only: bool,
use_ties: bool,
track_specs: Optional[Sequence[note_sequences.TrackSpec]] = None,
num_summary_examples: int = 5,
frame_fps: float = 62.5,
frame_velocity_threshold: int = 30,
) -> Mapping[str, seqio.metrics.MetricValue]:
"""Compute mir_eval transcription metrics."""
if onsets_only and use_ties:
raise ValueError('Ties not compatible with onset-only transcription.')
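  # Choose how predicted events are decoded into notes: onsets only, full
  # notes, or notes with ties.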
if onsets_only:
encoding_spec = note_sequences.NoteOnsetEncodingSpec
elif not use_ties:
encoding_spec = note_sequences.NoteEncodingSpec
else:
encoding_spec = note_sequences.NoteEncodingWithTiesSpec
# The first target for each full example contains the NoteSequence; just
# organize by ID.
full_targets = {}
for target in targets:
if target['ref_ns']:
full_targets[target['unique_id']] = {'ref_ns': target['ref_ns']}
# Gather all predictions for the same ID and concatenate them in time order,
# to construct full-length predictions.
full_predictions = metrics_utils.combine_predictions_by_id(
predictions=predictions,
combine_predictions_fn=functools.partial(
metrics_utils.event_predictions_to_ns,
codec=codec,
encoding_spec=encoding_spec))
assert sorted(full_targets.keys()) == sorted(full_predictions.keys())
  full_target_prediction_pairs = [
      (full_targets[unique_id], full_predictions[unique_id])
      for unique_id in sorted(full_targets.keys())
  ]
scores = collections.defaultdict(list)
all_track_pianorolls = collections.defaultdict(list)
for target, prediction in full_target_prediction_pairs:
scores['Invalid events'].append(prediction['est_invalid_events'])
scores['Dropped events'].append(prediction['est_dropped_events'])
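    # Drum notes are stripped before computing the combined note metrics,
    # since note offsets are not meaningful for drums; drum onsets are scored
    # separately below.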
def remove_drums(ns):
ns_drumless = note_seq.NoteSequence()
ns_drumless.CopyFrom(ns)
del ns_drumless.notes[:]
ns_drumless.notes.extend([note for note in ns.notes if not note.is_drum])
return ns_drumless
est_ns_drumless = remove_drums(prediction['est_ns'])
ref_ns_drumless = remove_drums(target['ref_ns'])
# Whether or not there are separate tracks, compute metrics for the full
# NoteSequence minus drums.
est_tracks = [est_ns_drumless]
ref_tracks = [ref_ns_drumless]
use_track_offsets = [not onsets_only]
use_track_velocities = [not onsets_only]
track_instrument_names = ['']
if track_specs is not None:
# Compute transcription metrics separately for each track.
for spec in track_specs:
est_tracks.append(note_sequences.extract_track(
prediction['est_ns'], spec.program, spec.is_drum))
ref_tracks.append(note_sequences.extract_track(
target['ref_ns'], spec.program, spec.is_drum))
use_track_offsets.append(not onsets_only and not spec.is_drum)
use_track_velocities.append(not onsets_only)
track_instrument_names.append(spec.name)
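    # Score each track: onset-only, onset + offset, and velocity-aware note
    # metrics via mir_eval, plus framewise metrics on prettymidi pianorolls.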
for est_ns, ref_ns, use_offsets, use_velocities, instrument_name in zip(
est_tracks, ref_tracks, use_track_offsets, use_track_velocities,
track_instrument_names):
track_scores = {}
est_intervals, est_pitches, est_velocities = (
note_seq.sequences_lib.sequence_to_valued_intervals(est_ns))
ref_intervals, ref_pitches, ref_velocities = (
note_seq.sequences_lib.sequence_to_valued_intervals(ref_ns))
# Precision / recall / F1 using onsets (and pitches) only.
precision, recall, f_measure, avg_overlap_ratio = (
mir_eval.transcription.precision_recall_f1_overlap(
ref_intervals=ref_intervals,
ref_pitches=ref_pitches,
est_intervals=est_intervals,
est_pitches=est_pitches,
offset_ratio=None))
del avg_overlap_ratio
track_scores['Onset precision'] = precision
track_scores['Onset recall'] = recall
track_scores['Onset F1'] = f_measure
if use_offsets:
# Precision / recall / F1 using onsets and offsets.
precision, recall, f_measure, avg_overlap_ratio = (
mir_eval.transcription.precision_recall_f1_overlap(
ref_intervals=ref_intervals,
ref_pitches=ref_pitches,
est_intervals=est_intervals,
est_pitches=est_pitches))
del avg_overlap_ratio
track_scores['Onset + offset precision'] = precision
track_scores['Onset + offset recall'] = recall
track_scores['Onset + offset F1'] = f_measure
if use_velocities:
# Precision / recall / F1 using onsets and velocities (no offsets).
precision, recall, f_measure, avg_overlap_ratio = (
mir_eval.transcription_velocity.precision_recall_f1_overlap(
ref_intervals=ref_intervals,
ref_pitches=ref_pitches,
ref_velocities=ref_velocities,
est_intervals=est_intervals,
est_pitches=est_pitches,
est_velocities=est_velocities,
offset_ratio=None))
track_scores['Onset + velocity precision'] = precision
track_scores['Onset + velocity recall'] = recall
track_scores['Onset + velocity F1'] = f_measure
if use_offsets and use_velocities:
# Precision / recall / F1 using onsets, offsets, and velocities.
precision, recall, f_measure, avg_overlap_ratio = (
mir_eval.transcription_velocity.precision_recall_f1_overlap(
ref_intervals=ref_intervals,
ref_pitches=ref_pitches,
ref_velocities=ref_velocities,
est_intervals=est_intervals,
est_pitches=est_pitches,
est_velocities=est_velocities))
track_scores['Onset + offset + velocity precision'] = precision
track_scores['Onset + offset + velocity recall'] = recall
track_scores['Onset + offset + velocity F1'] = f_measure
# Calculate framewise metrics.
      is_drum = all(n.is_drum for n in ref_ns.notes)
ref_pr = metrics_utils.get_prettymidi_pianoroll(
ref_ns, frame_fps, is_drum=is_drum)
est_pr = metrics_utils.get_prettymidi_pianoroll(
est_ns, frame_fps, is_drum=is_drum)
all_track_pianorolls[instrument_name].append((est_pr, ref_pr))
frame_precision, frame_recall, frame_f1 = metrics_utils.frame_metrics(
ref_pr, est_pr, velocity_threshold=frame_velocity_threshold)
track_scores['Frame Precision'] = frame_precision
track_scores['Frame Recall'] = frame_recall
track_scores['Frame F1'] = frame_f1
for metric_name, metric_value in track_scores.items():
if instrument_name:
scores[f'{instrument_name}/{metric_name}'].append(metric_value)
else:
scores[metric_name].append(metric_value)
# Add program-aware note metrics for all program granularities.
# Note that this interacts with the training program granularity; in
# particular granularities *higher* than the training granularity are likely
# to have poor metrics.
for granularity_type in vocabularies.PROGRAM_GRANULARITIES:
for name, score in _program_aware_note_scores(
target['ref_ns'], prediction['est_ns'],
granularity_type=granularity_type).items():
scores[name].append(score)
# Add (non-program-aware) note metrics across a range of onset/offset
# tolerances.
for name, score in _note_onset_tolerance_sweep(
ref_ns=ref_ns_drumless, est_ns=est_ns_drumless).items():
scores[name].append(score)
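  # Aggregate per-example scores: report the mean of each metric along with a
  # histogram of per-example values.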
mean_scores = {k: np.mean(v) for k, v in scores.items()}
  score_histograms = {f'{k} (hist)': seqio.metrics.Histogram(np.array(v))
                      for k, v in scores.items()}
# Pick several examples to summarize.
targets_to_summarize, predictions_to_summarize = zip(
*full_target_prediction_pairs[:num_summary_examples])
# Compute audio summaries.
audio_summaries = summaries.audio_summaries(
targets=targets_to_summarize,
predictions=predictions_to_summarize,
spectrogram_config=spectrogram_config)
# Compute transcription summaries.
transcription_summaries = summaries.transcription_summaries(
targets=targets_to_summarize,
predictions=predictions_to_summarize,
spectrogram_config=spectrogram_config,
ns_feature_suffix='ns',
track_specs=track_specs)
pianorolls_to_summarize = {
k: v[:num_summary_examples] for k, v in all_track_pianorolls.items()
}
prettymidi_pianoroll_summaries = summaries.prettymidi_pianoroll(
pianorolls_to_summarize, fps=frame_fps)
return {
**mean_scores,
**score_histograms,
**audio_summaries,
**transcription_summaries,
**prettymidi_pianoroll_summaries,
}