Spaces:
Build error
Build error
# Copyright 2022 The MT3 Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Utilities for transcription metrics.""" | |
import collections | |
import functools | |
from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, TypeVar | |
from mt3 import event_codec | |
from mt3 import note_sequences | |
from mt3 import run_length_encoding | |
import note_seq | |
import numpy as np | |
import pretty_midi | |
import sklearn | |
# Generic placeholders used by the signatures in this module: `S` is the type
# of a decoding state, `T` the type of a decoded result or prediction value.
S = TypeVar('S')
T = TypeVar('T')

# A function that merges all prediction examples sharing one ID into a single
# combined example.
CombineExamplesFunctionType = Callable[
    [Sequence[Mapping[str, Any]]],
    Mapping[str, Any]
]
def _group_predictions_by_id( | |
predictions: Sequence[Mapping[str, T]] | |
) -> Mapping[str, Sequence[T]]: | |
predictions_by_id = collections.defaultdict(list) | |
for pred in predictions: | |
predictions_by_id[pred['unique_id']].append(pred) | |
return predictions_by_id | |
def combine_predictions_by_id(
    predictions: Sequence[Mapping[str, Any]],
    combine_predictions_fn: CombineExamplesFunctionType
) -> Mapping[str, Mapping[str, Any]]:
  """Group predictions by unique ID and merge each group into one example.

  Args:
    predictions: Prediction dictionaries, each containing a 'unique_id' entry.
    combine_predictions_fn: Function that merges the list of predictions
      sharing one ID into a single combined example. Any ordering of the
      predictions (e.g. by time) is left to this function.

  Returns:
    A dictionary mapping each unique ID to its combined example.
  """
  combined = {}
  for unique_id, group in _group_predictions_by_id(predictions).items():
    combined[unique_id] = combine_predictions_fn(group)
  return combined
def decode_and_combine_predictions(
    predictions: Sequence[Mapping[str, Any]],
    init_state_fn: Callable[[], S],
    begin_segment_fn: Callable[[S], None],
    decode_tokens_fn: Callable[[S, Sequence[int], int, Optional[int]],
                               Tuple[int, int]],
    flush_state_fn: Callable[[S], T]
) -> Tuple[T, int, int]:
  """Decode per-segment predictions in time order and merge them.

  Segments are processed in order of increasing start time, threading a single
  decoding state through all of them. For time-based events this usually
  amounts to concatenation.

  Args:
    predictions: List of predictions, each a dictionary with estimated tokens
      ('est_tokens') and start time ('start_time') fields.
    init_state_fn: Zero-argument function returning the initial decoding state.
    begin_segment_fn: Function called with the decoding state at the start of
      every segment.
    decode_tokens_fn: Function called with the decoding state, one segment's
      estimated tokens, that segment's start time, and the max decode time. It
      updates the state in place and returns the number of invalid and dropped
      events for the segment.
    flush_state_fn: Function that turns the final decoding state into the
      result.

  Returns:
    result: The full combined decoding.
    total_invalid_events: Total number of invalid event tokens across all
      predictions.
    total_dropped_events: Total number of dropped event tokens across all
      predictions.
  """
  segments = sorted(predictions, key=lambda segment: segment['start_time'])

  # Depending on the audio token hop length, a symbolic token may correspond to
  # several audio frames, so adjacent segments can overlap. To avoid
  # overlapping output, each segment may only decode up to the start time of
  # the segment that follows it; the last segment is unbounded (None).
  decode_limits = [segment['start_time'] for segment in segments[1:]]
  decode_limits.append(None)

  state = init_state_fn()
  invalid_count = 0
  dropped_count = 0
  for segment, max_decode_time in zip(segments, decode_limits):
    begin_segment_fn(state)
    num_invalid, num_dropped = decode_tokens_fn(
        state, segment['est_tokens'], segment['start_time'], max_decode_time)
    invalid_count += num_invalid
    dropped_count += num_dropped

  return flush_state_fn(state), invalid_count, dropped_count
def event_predictions_to_ns(
    predictions: Sequence[Mapping[str, Any]], codec: event_codec.Codec,
    encoding_spec: note_sequences.NoteEncodingSpecType
) -> Mapping[str, Any]:
  """Decode a sequence of segment predictions into one combined NoteSequence.

  Args:
    predictions: Prediction dictionaries; each must provide 'est_tokens',
      'start_time' and 'raw_inputs'.
    codec: Event codec handed to `run_length_encoding.decode_events`.
    encoding_spec: Supplies the state-init, begin-segment, decode-event and
      flush functions used for decoding.

  Returns:
    A dictionary with:
      'raw_inputs': raw inputs of all predictions, concatenated along axis 0
        in start-time order.
      'start_times': start times of all predictions, in ascending order.
      'est_ns': the combined decoding result.
      'est_invalid_events': total number of invalid events.
      'est_dropped_events': total number of dropped events.
  """
  decode_segment_fn = functools.partial(
      run_length_encoding.decode_events,
      codec=codec,
      decode_event_fn=encoding_spec.decode_event_fn)
  est_ns, invalid_count, dropped_count = decode_and_combine_predictions(
      predictions=predictions,
      init_state_fn=encoding_spec.init_decoding_state_fn,
      begin_segment_fn=encoding_spec.begin_decoding_segment_fn,
      decode_tokens_fn=decode_segment_fn,
      flush_state_fn=encoding_spec.flush_decoding_state_fn)

  # The raw inputs are stitched together in the same time order as the decoding.
  by_start_time = sorted(predictions, key=lambda pred: pred['start_time'])
  segment_inputs = [segment['raw_inputs'] for segment in by_start_time]
  segment_starts = [segment['start_time'] for segment in by_start_time]

  return {
      'raw_inputs': np.concatenate(segment_inputs, axis=0),
      'start_times': segment_starts,
      'est_ns': est_ns,
      'est_invalid_events': invalid_count,
      'est_dropped_events': dropped_count,
  }
def get_prettymidi_pianoroll(ns: note_seq.NoteSequence, fps: float,
                             is_drum: bool):
  """Render a NoteSequence as a pianoroll via pretty_midi.

  Note that `ns` is modified in place: note end times are adjusted before the
  conversion.

  Args:
    ns: The NoteSequence to render.
    fps: Frame rate passed to pretty_midi's `get_piano_roll` as `fs`.
    is_drum: Whether the notes should be treated as drum hits.

  Returns:
    The pianoroll produced by `PrettyMIDI.get_piano_roll`.
  """
  min_note_length = 0.05
  for note in ns.notes:
    # Drum notes all get the same fixed length; any other note is stretched to
    # at least the minimum length.
    if is_drum or note.end_time - note.start_time < min_note_length:
      note.end_time = note.start_time + min_note_length

  pm = note_seq.note_sequence_to_pretty_midi(ns)
  end_time = pm.get_end_time()
  # At the very end, emit controller 120 (all sound off) and controller 123
  # (all notes off) on the first instrument.
  pm.instruments[0].control_changes = [
      pretty_midi.ControlChange(number=controller, value=0, time=end_time)
      for controller in (120, 123)
  ]

  if is_drum:
    # If inst.is_drum is set, pretty_midi will return an all zero pianoroll.
    for inst in pm.instruments:
      inst.is_drum = False

  return pm.get_piano_roll(fs=fps)
def frame_metrics(ref_pianoroll: np.ndarray,
                  est_pianoroll: np.ndarray,
                  velocity_threshold: int) -> Tuple[float, float, float]:
  """Compute frame-level precision, recall, and F1.

  Args:
    ref_pianoroll: Reference pianoroll; axis 1 is padded to match
      `est_pianoroll` if it is shorter.
    est_pianoroll: Estimated pianoroll; axis 1 is padded to match
      `ref_pianoroll` if it is shorter.
    velocity_threshold: Reference entries must be strictly greater than this
      value to count as active.

  Returns:
    A (precision, recall, f1) tuple for the "active" (True) class, computed
    over all flattened pianoroll entries.
  """
  # A bare `import sklearn` does not guarantee that the `sklearn.metrics`
  # submodule has been loaded, so import it explicitly rather than relying on
  # some other module having imported it first.
  from sklearn import metrics as sklearn_metrics  # pylint: disable=g-import-not-at-top

  # Zero-pad the shorter pianoroll along axis 1 so both have the same length.
  ref_len = ref_pianoroll.shape[1]
  est_len = est_pianoroll.shape[1]
  if ref_len > est_len:
    est_pianoroll = np.pad(
        est_pianoroll, [(0, 0), (0, ref_len - est_len)], mode='constant')
  elif est_len > ref_len:
    ref_pianoroll = np.pad(
        ref_pianoroll, [(0, 0), (0, est_len - ref_len)], mode='constant')

  # For ref, remove any notes that are too quiet (consistent with Cerberus.)
  ref_frames_bool = ref_pianoroll > velocity_threshold
  # For est, keep all predicted notes.
  est_frames_bool = est_pianoroll > 0

  precision, recall, f1, _ = sklearn_metrics.precision_recall_fscore_support(
      ref_frames_bool.flatten(),
      est_frames_bool.flatten(),
      labels=[True, False])

  # Index 0 corresponds to the first entry of `labels`, i.e. the True class.
  return precision[0], recall[0], f1[0]