# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Model vocabulary."""

import dataclasses
import math

from typing import Callable, Optional, Sequence
from mt3 import event_codec

import note_seq
import seqio
import t5.data
import tensorflow as tf


DECODED_EOS_ID = -1
DECODED_INVALID_ID = -2

# defaults for vocabulary config
DEFAULT_STEPS_PER_SECOND = 100
DEFAULT_MAX_SHIFT_SECONDS = 10
DEFAULT_NUM_VELOCITY_BINS = 127


@dataclasses.dataclass
class VocabularyConfig:
  """Vocabulary configuration parameters."""
  steps_per_second: int = DEFAULT_STEPS_PER_SECOND
  max_shift_seconds: int = DEFAULT_MAX_SHIFT_SECONDS
  num_velocity_bins: int = DEFAULT_NUM_VELOCITY_BINS

  @property
  def abbrev_str(self):
    s = ''
    if self.steps_per_second != DEFAULT_STEPS_PER_SECOND:
      s += 'ss%d' % self.steps_per_second
    if self.max_shift_seconds != DEFAULT_MAX_SHIFT_SECONDS:
      s += 'ms%d' % self.max_shift_seconds
    if self.num_velocity_bins != DEFAULT_NUM_VELOCITY_BINS:
      s += 'vb%d' % self.num_velocity_bins
    return s
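
# Illustrative example (not part of the original module): with the defaults
# above, VocabularyConfig().abbrev_str == '' while
# VocabularyConfig(num_velocity_bins=1).abbrev_str == 'vb1'; only non-default
# settings appear in the abbreviation.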


def num_velocity_bins_from_codec(codec: event_codec.Codec):
  """Get number of velocity bins from event codec."""
  lo, hi = codec.event_type_range('velocity')
  return hi - lo


def velocity_to_bin(velocity, num_velocity_bins):
  if velocity == 0:
    return 0
  else:
    return math.ceil(num_velocity_bins * velocity / note_seq.MAX_MIDI_VELOCITY)


def bin_to_velocity(velocity_bin, num_velocity_bins):
  if velocity_bin == 0:
    return 0
  else:
    return int(note_seq.MAX_MIDI_VELOCITY * velocity_bin / num_velocity_bins)
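
# Illustrative example (not in the original): with the default 127 velocity
# bins, bins and MIDI velocities coincide, e.g. velocity_to_bin(64, 127) == 64
# and bin_to_velocity(64, 127) == 64. With a single bin, every nonzero
# velocity maps to bin 1 and decodes back to velocity 127; bin 0 always means
# note-off.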


def drop_programs(tokens, codec: event_codec.Codec):
  """Drops program change events from a token sequence."""
  min_program_id, max_program_id = codec.event_type_range('program')
  return tokens[(tokens < min_program_id) | (tokens > max_program_id)]


def programs_to_midi_classes(tokens, codec):
  """Modifies program events to be the first program in the MIDI class."""
  min_program_id, max_program_id = codec.event_type_range('program')
  is_program = (tokens >= min_program_id) & (tokens <= max_program_id)
  return tf.where(
      is_program,
      min_program_id + 8 * ((tokens - min_program_id) // 8),
      tokens)
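
# Illustrative example (not in the original): each General MIDI class spans 8
# programs, so the token for program 27 (Electric Guitar (clean)) is mapped to
# the token for program 24, the first program in the guitar class.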


@dataclasses.dataclass
class ProgramGranularity:
  # both tokens_map_fn and program_map_fn should be idempotent
  tokens_map_fn: Callable[[Sequence[int], event_codec.Codec], Sequence[int]]
  program_map_fn: Callable[[int], int]


PROGRAM_GRANULARITIES = {
    # "flat" granularity; drop program change tokens and set NoteSequence
    # programs to zero
    'flat': ProgramGranularity(
        tokens_map_fn=drop_programs,
        program_map_fn=lambda program: 0),

    # map each program to the first program in its MIDI class
    'midi_class': ProgramGranularity(
        tokens_map_fn=programs_to_midi_classes,
        program_map_fn=lambda program: 8 * (program // 8)),

    # leave programs as is
    'full': ProgramGranularity(
        tokens_map_fn=lambda tokens, codec: tokens,
        program_map_fn=lambda program: program)
}
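
# Illustrative example (not in the original):
#   PROGRAM_GRANULARITIES['flat'].program_map_fn(27)        -> 0
#   PROGRAM_GRANULARITIES['midi_class'].program_map_fn(27)  -> 24
#   PROGRAM_GRANULARITIES['full'].program_map_fn(27)        -> 27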


def build_codec(vocab_config: VocabularyConfig):
  """Build event codec."""
  event_ranges = [
      event_codec.EventRange('pitch', note_seq.MIN_MIDI_PITCH,
                             note_seq.MAX_MIDI_PITCH),
      # velocity bin 0 is used for note-off
      event_codec.EventRange('velocity', 0, vocab_config.num_velocity_bins),
      # used to indicate that a pitch is present at the beginning of a segment
      # (only has an "off" event, since when using ties all pitch events
      # before the "tie" event belong to the tie section)
      event_codec.EventRange('tie', 0, 0),
      event_codec.EventRange('program', note_seq.MIN_MIDI_PROGRAM,
                             note_seq.MAX_MIDI_PROGRAM),
      event_codec.EventRange('drum', note_seq.MIN_MIDI_PITCH,
                             note_seq.MAX_MIDI_PITCH),
  ]

  return event_codec.Codec(
      max_shift_steps=(vocab_config.steps_per_second *
                       vocab_config.max_shift_seconds),
      steps_per_second=vocab_config.steps_per_second,
      event_ranges=event_ranges)
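
# Note (added for clarity, not in the original): with the default config this
# yields 100 * 10 = 1000 shift steps, i.e. 10 ms time shifts covering up to
# 10 seconds per segment, in addition to the pitch, velocity, tie, program,
# and drum ranges above.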


def vocabulary_from_codec(codec: event_codec.Codec) -> seqio.Vocabulary:
  return GenericTokenVocabulary(
      codec.num_classes, extra_ids=t5.data.DEFAULT_EXTRA_IDS)
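
# Illustrative usage (not in the original):
#   codec = build_codec(VocabularyConfig())
#   vocab = vocabulary_from_codec(codec)
# The resulting vocabulary has codec.num_classes regular ids plus 3 special
# tokens and the T5 default number of extra ids.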


class GenericTokenVocabulary(seqio.Vocabulary):
  """Vocabulary with pass-through encoding of tokens."""

  def __init__(self, regular_ids: int, extra_ids: int = 0):
    # The special tokens: 0=PAD, 1=EOS, and 2=UNK
    self._num_special_tokens = 3
    self._num_regular_tokens = regular_ids
    super().__init__(extra_ids=extra_ids)

  @property
  def eos_id(self) -> Optional[int]:
    return 1

  @property
  def unk_id(self) -> Optional[int]:
    return 2

  @property
  def _base_vocab_size(self) -> int:
    """Number of ids.

    Returns:
      an integer, the vocabulary size
    """
    return self._num_special_tokens + self._num_regular_tokens

  def _encode(self, token_ids: Sequence[int]) -> Sequence[int]:
    """Encode a list of tokens ids as a list of integers.

    To reserve the first few ids for special tokens, each regular token id is
    increased by the number of special tokens.

    Args:
      token_ids: array of token ids.

    Returns:
      a list of integers (not terminated by EOS)
    """
    encoded = []
    for token_id in token_ids:
      if not 0 <= token_id < self._num_regular_tokens:
        raise ValueError(
            f'token_id {token_id} does not fall within valid range of '
            f'[0, {self._num_regular_tokens})')
      encoded.append(token_id + self._num_special_tokens)

    return encoded

  def _decode(self, ids: Sequence[int]) -> Sequence[int]:
    """Decode a list of integers to a list of token ids.

    The special tokens of PAD and UNK as well as extra_ids will be
    replaced with DECODED_INVALID_ID in the output. If EOS is present, it will
    be the final token in the decoded output and will be represented by
    DECODED_EOS_ID.

    Args:
      ids: a list of integers

    Returns:
      a list of token ids.
    """
    # Map EOS to DECODED_EOS_ID; map special tokens, extra_ids, and any
    # out-of-range ids to DECODED_INVALID_ID.
    def _decode_id(encoded_id):
      if encoded_id == self.eos_id:
        return DECODED_EOS_ID
      elif encoded_id < self._num_special_tokens:
        return DECODED_INVALID_ID
      elif encoded_id >= self._base_vocab_size:
        return DECODED_INVALID_ID
      else:
        return encoded_id - self._num_special_tokens
    ids = [_decode_id(int(i)) for i in ids]
    return ids
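
  # Illustrative example (not in the original): _encode([0, 5]) == [3, 8]
  # (ids shifted past the 3 special tokens), and _decode([3, 8, 1, 0]) ==
  # [0, 5, DECODED_EOS_ID, DECODED_INVALID_ID]. Unlike _decode_tf below,
  # _decode maps each id independently and does not stop at EOS.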

  def _encode_tf(self, token_ids: tf.Tensor) -> tf.Tensor:
    """Encode a list of tokens to a tf.Tensor.

    Args:
      token_ids: array of token ids.

    Returns:
      a 1d tf.Tensor with dtype tf.int32
    """
    with tf.control_dependencies(
        [tf.debugging.assert_less(
            token_ids, tf.cast(self._num_regular_tokens, token_ids.dtype)),
         tf.debugging.assert_greater_equal(
             token_ids, tf.cast(0, token_ids.dtype))
         ]):
      tf_ids = token_ids + self._num_special_tokens
    return tf_ids

  def _decode_tf(self, ids: tf.Tensor) -> tf.Tensor:
    """Decode in TensorFlow.

    The special tokens of PAD and UNK as well as extra_ids will be
    replaced with DECODED_INVALID_ID in the output. If EOS is present, it and
    all following tokens in the decoded output will be represented by
    DECODED_EOS_ID.

    Args:
      ids: a 1d tf.Tensor with dtype tf.int32

    Returns:
      a 1d tf.Tensor with dtype tf.int32
    """
    # Create a mask that is true from the first EOS position onward.
    # First, create an array that is True whenever there is an EOS, then cumsum
    # that array so that every position at or after the first True is >= 1,
    # then cast back to bool for the final mask.
    eos_and_after = tf.cumsum(
        tf.cast(tf.equal(ids, self.eos_id), tf.int32), exclusive=False, axis=-1)
    eos_and_after = tf.cast(eos_and_after, tf.bool)

    return tf.where(
        eos_and_after,
        DECODED_EOS_ID,
        tf.where(
            tf.logical_and(
                tf.greater_equal(ids, self._num_special_tokens),
                tf.less(ids, self._base_vocab_size)),
            ids - self._num_special_tokens,
            DECODED_INVALID_ID))
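
  # Illustrative example (not in the original): _decode_tf([3, 8, 1, 5])
  # yields [0, 5, DECODED_EOS_ID, DECODED_EOS_ID], since every position from
  # the first EOS onward is mapped to DECODED_EOS_ID.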

  def __eq__(self, other):
    their_extra_ids = other.extra_ids
    their_num_regular_tokens = other._num_regular_tokens
    return (self.extra_ids == their_extra_ids and
            self._num_regular_tokens == their_num_regular_tokens)


def num_embeddings(vocabulary: GenericTokenVocabulary) -> int:
  """Vocabulary size as a multiple of 128 for TPU efficiency."""
  return 128 * math.ceil(vocabulary.vocab_size / 128)
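
# Illustrative example (not in the original): a vocabulary whose vocab_size is
# 1536 is already a multiple of 128 and is left unchanged, while a vocab_size
# of 1537 is rounded up to 1664, the next multiple of 128.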