# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for MT3 inference.""" | |
import functools | |
import json | |
from typing import Any, Optional, Sequence | |
import gin | |
from mt3 import metrics_utils | |
from mt3 import note_sequences | |
from mt3 import tasks | |
from mt3 import vocabularies | |
import note_seq | |
import seqio | |
import tensorflow as tf | |
def write_inferences_to_file(
    path: str,
    inferences: Sequence[Any],
    task_ds: tf.data.Dataset,
    mode: str,
    vocabulary: Optional[seqio.Vocabulary] = None,
    vocab_config=gin.REQUIRED,
    onsets_only=gin.REQUIRED,
    use_ties=gin.REQUIRED) -> None:
  """Writes model predictions and ground-truth transcriptions as JSON lines.

  Each full example produces one JSON line containing the reference
  NoteSequence id and the estimated notes decoded from model output tokens.
  For now this only works for transcription tasks with ties.

  Args:
    path: File path to write to.
    inferences: Model inferences, output of predict_batch.
    task_ds: Original task dataset; iterated in lockstep with `inferences`.
    mode: Prediction mode; must be 'predict' as 'score' is not supported.
    vocabulary: Task output vocabulary; required in 'predict' mode.
    vocab_config: Vocabulary config object used to build the event codec.
    onsets_only: If True, only predict onsets.
    use_ties: If True, use "tie" representation.

  Raises:
    ValueError: If `mode` is 'score', `vocabulary` is missing, or
      `onsets_only` and `use_ties` are both set (an unsupported combination).
  """
  if mode == 'score':
    raise ValueError('`score` mode currently not supported in MT3')
  if not vocabulary:
    raise ValueError('`vocabulary` parameter required in `predict` mode')
  if onsets_only and use_ties:
    raise ValueError('ties not compatible with onset-only transcription')

  # Pick the event encoding matching the task configuration.
  if onsets_only:
    encoding_spec = note_sequences.NoteOnsetEncodingSpec
  elif not use_ties:
    encoding_spec = note_sequences.NoteEncodingSpec
  else:
    encoding_spec = note_sequences.NoteEncodingWithTiesSpec

  codec = vocabularies.build_codec(vocab_config)

  # Collect per-segment targets and predictions; segments belonging to the
  # same full example share a 'unique_id' and are recombined below.
  targets = []
  predictions = []
  for inp, output in zip(task_ds.as_numpy_iterator(), inferences):
    tokens = tasks.trim_eos(vocabulary.decode_tf(output).numpy())

    start_time = inp['input_times'][0]
    # Round down to the nearest symbolic token step so segment start times
    # align with the codec's time grid.
    start_time -= start_time % (1 / codec.steps_per_second)

    targets.append({
        'unique_id': inp['unique_id'][0],
        'ref_ns': inp['sequence'][0] if inp['sequence'][0] else None,
    })

    predictions.append({
        'unique_id': inp['unique_id'][0],
        'est_tokens': tokens,
        'start_time': start_time,
        # Input audio is not part of the "prediction" but the below call to
        # metrics_utils.event_predictions_to_ns handles the concatenation.
        'raw_inputs': inp['raw_inputs']
    })

  # The first target for each full example contains the NoteSequence; just
  # organize by ID.
  full_targets = {}
  for target in targets:
    if target['ref_ns']:
      full_targets[target['unique_id']] = {
          'ref_ns': note_seq.NoteSequence.FromString(target['ref_ns'])
      }

  full_predictions = metrics_utils.combine_predictions_by_id(
      predictions=predictions,
      combine_predictions_fn=functools.partial(
          metrics_utils.event_predictions_to_ns,
          codec=codec,
          encoding_spec=encoding_spec))

  # Every target id must have a combined prediction and vice versa.
  assert sorted(full_targets.keys()) == sorted(full_predictions.keys())

  # NOTE: loop variable renamed from `id` to avoid shadowing the builtin.
  full_target_prediction_pairs = [
      (full_targets[unique_id], full_predictions[unique_id])
      for unique_id in sorted(full_targets.keys())
  ]

  def note_to_dict(note):
    # Flatten a NoteSequence note proto into a JSON-serializable dict.
    return {
        'start_time': note.start_time,
        'end_time': note.end_time,
        'pitch': note.pitch,
        'velocity': note.velocity,
        'program': note.program,
        'is_drum': note.is_drum
    }

  with tf.io.gfile.GFile(path, 'w') as f:
    for target, prediction in full_target_prediction_pairs:
      json_dict = {
          'id': target['ref_ns'].id,
          'est_notes':
              [note_to_dict(note) for note in prediction['est_ns'].notes]
      }
      json_str = json.dumps(json_dict, cls=seqio.TensorAndNumpyEncoder)
      f.write(json_str + '\n')