# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions for MT3 inference."""

import functools
import json

from typing import Any, Optional, Sequence

import gin

from mt3 import metrics_utils
from mt3 import note_sequences
from mt3 import tasks
from mt3 import vocabularies

import note_seq
import seqio
import tensorflow as tf


def write_inferences_to_file(
    path: str,
    inferences: Sequence[Any],
    task_ds: tf.data.Dataset,
    mode: str,
    vocabulary: Optional[seqio.Vocabulary] = None,
    vocab_config=gin.REQUIRED,
    onsets_only=gin.REQUIRED,
    use_ties=gin.REQUIRED) -> None:
  """Writes model predictions, ground truth transcriptions, and input audio.

  For now this only works for transcription tasks with ties.

  Args:
    path: File path to write to.
    inferences: Model inferences, output of predict_batch.
    task_ds: Original task dataset.
    mode: Prediction mode; must be 'predict' as 'score' is not supported.
    vocabulary: Task output vocabulary.
    vocab_config: Vocabulary config object.
    onsets_only: If True, only predict onsets.
    use_ties: If True, use "tie" representation.
  """
  if mode == 'score':
    raise ValueError('`score` mode currently not supported in MT3')
  if not vocabulary:
    raise ValueError('`vocabulary` parameter required in `predict` mode')

  if onsets_only and use_ties:
    raise ValueError('ties not compatible with onset-only transcription')
  if onsets_only:
    encoding_spec = note_sequences.NoteOnsetEncodingSpec
  elif not use_ties:
    encoding_spec = note_sequences.NoteEncodingSpec
  else:
    encoding_spec = note_sequences.NoteEncodingWithTiesSpec

  codec = vocabularies.build_codec(vocab_config)

  targets = []
  predictions = []

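  # `task_ds` yields the original task segments in the same order as
  # `inferences`, so zipping pairs each model output with its source segment.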
  for inp, output in zip(task_ds.as_numpy_iterator(), inferences):
    tokens = tasks.trim_eos(vocabulary.decode_tf(output).numpy())

    start_time = inp['input_times'][0]
    # Round down to nearest symbolic token step.
    start_time -= start_time % (1 / codec.steps_per_second)

    targets.append({
        'unique_id': inp['unique_id'][0],
        'ref_ns': inp['sequence'][0] if inp['sequence'][0] else None,
    })

    predictions.append({
        'unique_id': inp['unique_id'][0],
        'est_tokens': tokens,
        'start_time': start_time,
        # Input audio is not part of the "prediction" but the below call to
        # metrics_utils.event_predictions_to_ns handles the concatenation.
        'raw_inputs': inp['raw_inputs']
    })

  # The first target for each full example contains the NoteSequence; just
  # organize by ID.
  full_targets = {}
  for target in targets:
    if target['ref_ns']:
      full_targets[target['unique_id']] = {
          'ref_ns': note_seq.NoteSequence.FromString(target['ref_ns'])
      }

  full_predictions = metrics_utils.combine_predictions_by_id(
      predictions=predictions,
      combine_predictions_fn=functools.partial(
          metrics_utils.event_predictions_to_ns,
          codec=codec,
          encoding_spec=encoding_spec))

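  # Every example id should appear in both the reference and prediction maps.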
  assert sorted(full_targets.keys()) == sorted(full_predictions.keys())

  full_target_prediction_pairs = [
      (full_targets[example_id], full_predictions[example_id])
      for example_id in sorted(full_targets.keys())
  ]

  def note_to_dict(note):
    return {
        'start_time': note.start_time,
        'end_time': note.end_time,
        'pitch': note.pitch,
        'velocity': note.velocity,
        'program': note.program,
        'is_drum': note.is_drum
    }

  with tf.io.gfile.GFile(path, 'w') as f:
    for target, prediction in full_target_prediction_pairs:
      json_dict = {
          'id': target['ref_ns'].id,
          'est_notes':
              [note_to_dict(note) for note in prediction['est_ns'].notes]
      }
      json_str = json.dumps(json_dict, cls=seqio.TensorAndNumpyEncoder)
      f.write(json_str + '\n')
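

# Example usage (a sketch, not part of this module): in practice this function
# is bound through gin and called by the T5X/seqio inference loop as a write
# function, so the direct call below only illustrates the expected argument
# shapes. `my_inferences` and `my_task_ds` are hypothetical placeholders for
# the output of `model.predict_batch` and the original task dataset, and the
# vocabulary construction assumes the usual `vocabularies.VocabularyConfig` /
# `vocabularies.vocabulary_from_codec` helpers.
#
#   vocab_config = vocabularies.VocabularyConfig()
#   codec = vocabularies.build_codec(vocab_config)
#   vocabulary = vocabularies.vocabulary_from_codec(codec)
#   write_inferences_to_file(
#       path='/tmp/predictions.jsonl',
#       inferences=my_inferences,
#       task_ds=my_task_ds,
#       mode='predict',
#       vocabulary=vocabulary,
#       vocab_config=vocab_config,
#       onsets_only=False,
#       use_ties=True)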