# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for MT3 inference."""
import functools
import json
from typing import Any, Optional, Sequence
import gin
from mt3 import metrics_utils
from mt3 import note_sequences
from mt3 import tasks
from mt3 import vocabularies
import note_seq
import seqio
import tensorflow as tf
def write_inferences_to_file(
    path: str,
    inferences: Sequence[Any],
    task_ds: tf.data.Dataset,
    mode: str,
    vocabulary: Optional[seqio.Vocabulary] = None,
    vocab_config=gin.REQUIRED,
    onsets_only=gin.REQUIRED,
    use_ties=gin.REQUIRED) -> None:
  """Writes model predictions, ground truth transcriptions, and input audio.

  For now this only works for transcription tasks with ties.

  Args:
    path: File path to write to.
    inferences: Model inferences, output of predict_batch.
    task_ds: Original task dataset.
    mode: Prediction mode; must be 'predict' as 'score' is not supported.
    vocabulary: Task output vocabulary.
    vocab_config: Vocabulary config object.
    onsets_only: If True, only predict onsets.
    use_ties: If True, use "tie" representation.
  """
  if mode == 'score':
    raise ValueError('`score` mode currently not supported in MT3')
  if not vocabulary:
    raise ValueError('`vocabulary` parameter required in `predict` mode')

  if onsets_only and use_ties:
    raise ValueError('ties not compatible with onset-only transcription')
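  # Pick the event encoding spec: onsets only, full notes, or notes plus
  # "tie" events that mark notes still sounding at a segment boundary.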
  if onsets_only:
    encoding_spec = note_sequences.NoteOnsetEncodingSpec
  elif not use_ties:
    encoding_spec = note_sequences.NoteEncodingSpec
  else:
    encoding_spec = note_sequences.NoteEncodingWithTiesSpec
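
  # Build the event codec from the vocabulary config; it defines the mapping
  # between tokens and symbolic events, including the symbolic time
  # resolution (steps_per_second) used below.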
  codec = vocabularies.build_codec(vocab_config)

  targets = []
  predictions = []

  for inp, output in zip(task_ds.as_numpy_iterator(), inferences):
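    # Decode predicted token IDs back into codec tokens, dropping EOS and
    # anything after it.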
    tokens = tasks.trim_eos(vocabulary.decode_tf(output).numpy())

    start_time = inp['input_times'][0]
    # Round down to nearest symbolic token step.
    start_time -= start_time % (1 / codec.steps_per_second)

    targets.append({
        'unique_id': inp['unique_id'][0],
        'ref_ns': inp['sequence'][0] if inp['sequence'][0] else None,
    })

    predictions.append({
        'unique_id': inp['unique_id'][0],
        'est_tokens': tokens,
        'start_time': start_time,
        # Input audio is not part of the "prediction" but the below call to
        # metrics_utils.event_predictions_to_ns handles the concatenation.
        'raw_inputs': inp['raw_inputs']
    })

  # The first target for each full example contains the NoteSequence; just
  # organize by ID.
  full_targets = {}
  for target in targets:
    if target['ref_ns']:
      full_targets[target['unique_id']] = {
          'ref_ns': note_seq.NoteSequence.FromString(target['ref_ns'])
      }
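
  # Stitch the per-segment predictions for each example ID back together into
  # a single estimated NoteSequence.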
  full_predictions = metrics_utils.combine_predictions_by_id(
      predictions=predictions,
      combine_predictions_fn=functools.partial(
          metrics_utils.event_predictions_to_ns,
          codec=codec,
          encoding_spec=encoding_spec))

  assert sorted(full_targets.keys()) == sorted(full_predictions.keys())

  full_target_prediction_pairs = [
      (full_targets[id], full_predictions[id])
      for id in sorted(full_targets.keys())
  ]

  def note_to_dict(note):
    return {
        'start_time': note.start_time,
        'end_time': note.end_time,
        'pitch': note.pitch,
        'velocity': note.velocity,
        'program': note.program,
        'is_drum': note.is_drum
    }
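
  # Write one JSON object per line: the reference NoteSequence id and the
  # list of estimated notes.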
  with tf.io.gfile.GFile(path, 'w') as f:
    for target, prediction in full_target_prediction_pairs:
      json_dict = {
          'id': target['ref_ns'].id,
          'est_notes':
              [note_to_dict(note) for note in prediction['est_ns'].notes]
      }

      json_str = json.dumps(json_dict, cls=seqio.TensorAndNumpyEncoder)
      f.write(json_str + '\n')
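

# Example usage (illustrative sketch only, not part of the library API): in
# practice `vocab_config`, `onsets_only`, and `use_ties` are bound via gin and
# the dataset and inferences come from the seqio/T5X inference pipeline. The
# names `my_task_ds` and `my_inferences` below are hypothetical placeholders,
# and the sketch assumes `VocabularyConfig` and `vocabulary_from_codec` from
# this repo's `vocabularies` module.
#
#   vocab_config = vocabularies.VocabularyConfig()
#   codec = vocabularies.build_codec(vocab_config)
#   write_inferences_to_file(
#       path='/tmp/inferences.jsonl',
#       inferences=my_inferences,  # output of the model's predict_batch
#       task_ds=my_task_ds,        # the original task dataset
#       mode='predict',
#       vocabulary=vocabularies.vocabulary_from_codec(codec),
#       vocab_config=vocab_config,
#       onsets_only=False,
#       use_ties=True)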