# File size: 3,396 Bytes
# b100e1c
# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions for mixing (in the audio sense) multiple transcription examples."""

from typing import Callable, Optional, Sequence

import gin

from mt3 import event_codec
from mt3 import run_length_encoding

import numpy as np
import seqio
import tensorflow as tf


@gin.configurable
def mix_transcription_examples(
    ds: tf.data.Dataset,
    sequence_length: seqio.preprocessors.SequenceLengthType,
    output_features: seqio.preprocessors.OutputFeaturesType,
    codec: event_codec.Codec,
    inputs_feature_key: str = 'inputs',
    targets_feature_keys: Sequence[str] = ('targets',),
    max_examples_per_mix: Optional[int] = None,
    shuffle_buffer_size: int = seqio.SHUFFLE_BUFFER_SIZE
) -> tf.data.Dataset:
  """Preprocessor that mixes together "batches" of transcription examples.

  For each k in [1, max_examples_per_mix], an independently shuffled copy of
  `ds` is padded-batched into groups of k examples. Mixed examples are drawn
  from these k streams with `tf.data.Dataset.sample_from_datasets` (no explicit
  weights). Within each group:
    * the audio under `inputs_feature_key` is summed across the group and then
      divided by its peak absolute value (an all-zero mix stays all-zero), and
    * each feature in `targets_feature_keys` is merged separately with
      `run_length_encoding.merge_run_length_encoded_targets`.

  Args:
    ds: Dataset of individual transcription examples, each of which should
        have an 'inputs' field containing 1D audio samples (currently only
        audio encoders that use raw samples as an intermediate representation
        are supported), and a 'targets' field containing run-length encoded
        note events.
    sequence_length: Dictionary mapping feature key to length. Not used by the
        mixing logic. It is kept in the signature for interface compatibility
        (presumably because seqio passes it by keyword; confirm).
    output_features: Dictionary mapping feature key to spec. Not used by the
        mixing logic. It is kept in the signature for interface compatibility.
    codec: An event_codec.Codec used to interpret the target events.
    inputs_feature_key: Feature key for inputs which will be mixed as audio.
    targets_feature_keys: List of feature keys for targets, each of which will
        be merged (separately) as run-length encoded note events.
    max_examples_per_mix: Maximum number of individual examples to mix together.
        If None, mixing is disabled and `ds` is returned unchanged. Otherwise it
        must be at least 1.
    shuffle_buffer_size: Total shuffle-buffer budget. It is split evenly across
        the max_examples_per_mix shuffled copies of `ds`, and each copy gets a
        buffer of at least 1.

  Returns:
    Dataset containing mixed examples (or `ds` itself when
    max_examples_per_mix is None).

  Raises:
    ValueError: If max_examples_per_mix is not None and is less than 1.
  """
  del sequence_length, output_features  # Unused; see Args.

  if max_examples_per_mix is None:
    return ds
  if max_examples_per_mix < 1:
    raise ValueError(
        'max_examples_per_mix must be None or a positive integer; got '
        f'{max_examples_per_mix}.')

  # The shuffle budget is shared by the max_examples_per_mix shuffled copies
  # built below. Clamp to 1 because tf.data rejects a zero-size shuffle buffer,
  # which the plain integer division yields when the budget is smaller than
  # the number of copies.
  per_stream_buffer_size = max(1, shuffle_buffer_size // max_examples_per_mix)

  # TODO(iansimon): is there a way to use seqio's seed?
  # Stream k yields padded batches of k examples. The last batch may be
  # smaller, since padded_batch keeps the remainder by default.
  ds = tf.data.Dataset.sample_from_datasets([
      ds.shuffle(buffer_size=per_stream_buffer_size).padded_batch(batch_size=k)
      for k in range(1, max_examples_per_mix + 1)
  ])

  def mix_inputs(ex):
    # Sum the padded audio over the batch axis, then scale by the peak absolute
    # sample (inf-norm). divide_no_nan returns 0 instead of NaN for silent mixes.
    samples = tf.reduce_sum(ex[inputs_feature_key], axis=0)
    norm = tf.linalg.norm(samples, ord=np.inf)
    ex[inputs_feature_key] = tf.math.divide_no_nan(samples, norm)
    return ex
  ds = ds.map(mix_inputs, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  def mix_targets(ex):
    # Merge the batched run-length encoded targets separately for each key.
    for k in targets_feature_keys:
      ex[k] = run_length_encoding.merge_run_length_encoded_targets(
          targets=ex[k],
          codec=codec)
    return ex
  ds = ds.map(mix_targets, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  return ds