# youtube-music-transcribe / mt3 / vocabularies.py
# Uploaded by juancopi81 in commit b100e1c: "Add t5x and mt3 models".
# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model vocabulary."""
import dataclasses
import math
from typing import Callable, Optional, Sequence
from mt3 import event_codec
import note_seq
import seqio
import t5.data
import tensorflow as tf
# Sentinel values produced when decoding ids back into token ids.
DECODED_EOS_ID = -1
DECODED_INVALID_ID = -2

# defaults for vocabulary config
DEFAULT_STEPS_PER_SECOND = 100
DEFAULT_MAX_SHIFT_SECONDS = 10
DEFAULT_NUM_VELOCITY_BINS = 127


@dataclasses.dataclass
class VocabularyConfig:
  """Parameters that determine the shape of the event vocabulary.

  Attributes:
    steps_per_second: time-shift steps per second of audio.
    max_shift_seconds: longest single time shift, in seconds; together with
      steps_per_second it yields the codec's maximum shift in steps.
    num_velocity_bins: number of velocity bins (bin 0 is used for note-off).
  """
  steps_per_second: int = DEFAULT_STEPS_PER_SECOND
  max_shift_seconds: int = DEFAULT_MAX_SHIFT_SECONDS
  num_velocity_bins: int = DEFAULT_NUM_VELOCITY_BINS

  @property
  def abbrev_str(self):
    """Compact tag naming only the fields that differ from their defaults.

    Each non-default field contributes a two-letter prefix followed by its
    integer value ('ss' steps_per_second, 'ms' max_shift_seconds,
    'vb' num_velocity_bins), in that order; an all-default config gives ''.
    """
    fields = (
        ('ss', self.steps_per_second, DEFAULT_STEPS_PER_SECOND),
        ('ms', self.max_shift_seconds, DEFAULT_MAX_SHIFT_SECONDS),
        ('vb', self.num_velocity_bins, DEFAULT_NUM_VELOCITY_BINS),
    )
    return ''.join('%s%d' % (prefix, value)
                   for prefix, value, default in fields
                   if value != default)
def num_velocity_bins_from_codec(codec: event_codec.Codec):
  """Return how many velocity bins `codec` encodes.

  The result is the difference between the two bounds reported by
  `codec.event_type_range('velocity')`. NOTE(review): this equals
  num_velocity_bins only if those bounds are the inclusive first/last ids of
  a 0..num_velocity_bins range, as built by build_codec; confirm.
  """
  first_id, last_id = codec.event_type_range('velocity')
  return last_id - first_id
def velocity_to_bin(velocity, num_velocity_bins):
  """Quantize a MIDI velocity into a velocity bin.

  Velocity 0 maps to bin 0. Any other velocity maps to
  ceil(num_velocity_bins * velocity / note_seq.MAX_MIDI_VELOCITY).
  """
  return (0 if velocity == 0 else
          math.ceil(num_velocity_bins * velocity /
                    note_seq.MAX_MIDI_VELOCITY))
def bin_to_velocity(velocity_bin, num_velocity_bins):
  """Map a velocity bin back to a MIDI velocity.

  Bin 0 maps to velocity 0. Any other bin maps to
  int(note_seq.MAX_MIDI_VELOCITY * velocity_bin / num_velocity_bins), which
  truncates toward zero.
  """
  if velocity_bin == 0:
    return 0
  scaled = note_seq.MAX_MIDI_VELOCITY * velocity_bin / num_velocity_bins
  return int(scaled)
def drop_programs(tokens, codec: event_codec.Codec):
  """Drops program change events from a token sequence.

  Args:
    tokens: token array supporting elementwise comparison and boolean-mask
      indexing (presumably a tf.Tensor or np.ndarray; confirm with callers).
    codec: event codec used to look up the id range of 'program' events.

  Returns:
    `tokens` with every id inside the program range removed.
  """
  program_lo, program_hi = codec.event_type_range('program')
  keep = (tokens < program_lo) | (tokens > program_hi)
  return tokens[keep]
def programs_to_midi_classes(tokens, codec):
  """Snap each program token to the first program of its MIDI class.

  Programs are grouped into classes of 8 consecutive program numbers.
  Tokens outside the codec's 'program' range pass through unchanged.

  Args:
    tokens: token tensor (passed to tf.where, so presumably a tf.Tensor).
    codec: event codec used to look up the id range of 'program' events.

  Returns:
    A tensor like `tokens` with program ids rounded down to their class start.
  """
  first_program, last_program = codec.event_type_range('program')
  in_program_range = (tokens >= first_program) & (tokens <= last_program)
  offset_in_range = tokens - first_program
  class_start = first_program + 8 * (offset_in_range // 8)
  return tf.where(in_program_range, class_start, tokens)
@dataclasses.dataclass
class ProgramGranularity:
  """Pair of mapping functions that coarsen MIDI program information.

  The instances defined in PROGRAM_GRANULARITIES ('flat', 'midi_class',
  'full') show the intended use: one function rewrites program tokens in a
  token sequence, the other rewrites a single program number.
  """
  # both tokens_map_fn and program_map_fn should be idempotent
  # Takes (tokens, codec) and returns tokens with program events rewritten
  # or removed.
  tokens_map_fn: Callable[[Sequence[int], event_codec.Codec], Sequence[int]]
  # Takes a single program number and returns the program number to use.
  program_map_fn: Callable[[int], int]
# Available program granularities, keyed by name.
PROGRAM_GRANULARITIES = dict(
    # "flat": remove program-change tokens entirely and collapse every
    # NoteSequence program to 0.
    flat=ProgramGranularity(
        tokens_map_fn=drop_programs,
        program_map_fn=lambda program: 0),
    # "midi_class": replace each program with the first program of its
    # 8-program MIDI class.
    midi_class=ProgramGranularity(
        tokens_map_fn=programs_to_midi_classes,
        program_map_fn=lambda program: 8 * (program // 8)),
    # "full": keep programs exactly as they are.
    full=ProgramGranularity(
        tokens_map_fn=lambda tokens, codec: tokens,
        program_map_fn=lambda program: program),
)
def build_codec(vocab_config: VocabularyConfig):
  """Build the event codec described by `vocab_config`.

  The codec allows shifts of up to
  steps_per_second * max_shift_seconds steps and has the event ranges
  pitch, velocity, tie, program and drum, in that order. NOTE(review): the
  order presumably determines each range's id offset inside the codec, so
  do not reorder; confirm against event_codec.
  """
  lowest_pitch = note_seq.MIN_MIDI_PITCH
  highest_pitch = note_seq.MAX_MIDI_PITCH
  event_ranges = [
      event_codec.EventRange('pitch', lowest_pitch, highest_pitch),
      # Velocity bin 0 is reserved for note-off.
      event_codec.EventRange('velocity', 0, vocab_config.num_velocity_bins),
      # Marks pitches already present at the start of a segment. It has a
      # single value because, when ties are used, every pitch event before
      # the "tie" event belongs to the tie section.
      event_codec.EventRange('tie', 0, 0),
      event_codec.EventRange('program', note_seq.MIN_MIDI_PROGRAM,
                             note_seq.MAX_MIDI_PROGRAM),
      event_codec.EventRange('drum', lowest_pitch, highest_pitch),
  ]
  steps_per_second = vocab_config.steps_per_second
  max_shift_steps = steps_per_second * vocab_config.max_shift_seconds
  return event_codec.Codec(
      max_shift_steps=max_shift_steps,
      steps_per_second=steps_per_second,
      event_ranges=event_ranges)
def vocabulary_from_codec(codec: event_codec.Codec) -> seqio.Vocabulary:
  """Wrap `codec` in a pass-through vocabulary.

  The vocabulary has one regular id per codec class plus T5's default
  number of extra ids.
  """
  regular_ids = codec.num_classes
  extra_ids = t5.data.DEFAULT_EXTRA_IDS
  return GenericTokenVocabulary(regular_ids, extra_ids=extra_ids)
class GenericTokenVocabulary(seqio.Vocabulary):
  """Vocabulary with pass-through encoding of tokens.

  Regular token ids in [0, regular_ids) are encoded by shifting them up past
  the three reserved special ids (0=PAD, 1=EOS, 2=UNK). They are decoded by
  shifting them back down. On decode, EOS becomes DECODED_EOS_ID. Every
  other id outside the regular range (PAD, UNK, and ids at or above
  `_base_vocab_size`, which includes the extra ids) becomes
  DECODED_INVALID_ID.
  """

  def __init__(self, regular_ids: int, extra_ids: int = 0):
    """Initializes the vocabulary.

    Args:
      regular_ids: number of regular (non-special) token ids.
      extra_ids: number of extra ids, forwarded to seqio.Vocabulary.
    """
    # The special tokens: 0=PAD, 1=EOS, and 2=UNK
    self._num_special_tokens = 3
    self._num_regular_tokens = regular_ids
    super().__init__(extra_ids=extra_ids)

  @property
  def eos_id(self) -> Optional[int]:
    """Encoded id of the end-of-sequence token."""
    return 1

  @property
  def unk_id(self) -> Optional[int]:
    """Encoded id of the unknown token."""
    return 2

  @property
  def _base_vocab_size(self) -> int:
    """Number of ids, not counting extra ids.

    Returns:
      an integer: the number of special tokens plus regular tokens
    """
    return self._num_special_tokens + self._num_regular_tokens

  def _encode(self, token_ids: Sequence[int]) -> Sequence[int]:
    """Encode a list of token ids as a list of integers.

    To keep the first few ids for special tokens, each id is increased by the
    number of special tokens.

    Args:
      token_ids: array of token ids, each expected in [0, regular_ids).

    Returns:
      a list of integers (not terminated by EOS)

    Raises:
      ValueError: if any token id falls outside [0, regular_ids).
    """
    encoded = []
    for token_id in token_ids:
      if not 0 <= token_id < self._num_regular_tokens:
        raise ValueError(
            f'token_id {token_id} does not fall within valid range of '
            f'[0, {self._num_regular_tokens})')
      encoded.append(token_id + self._num_special_tokens)
    return encoded

  def _decode(self, ids: Sequence[int]) -> Sequence[int]:
    """Decode a list of integers to a list of token ids.

    EOS is replaced with DECODED_EOS_ID. PAD, UNK, and any id at or above
    `_base_vocab_size` (which includes the extra ids) are replaced with
    DECODED_INVALID_ID. Every other id is shifted down by the number of
    special tokens.

    This method maps ids one-to-one and does not itself stop at EOS.
    NOTE(review): the guarantee that EOS is the final decoded token
    presumably comes from seqio.Vocabulary's public `decode` truncating the
    input after the first EOS before calling this method; confirm.

    Args:
      ids: a list of integers

    Returns:
      a list of token ids.
    """
    def _decode_id(encoded_id):
      if encoded_id == self.eos_id:
        return DECODED_EOS_ID
      elif encoded_id < self._num_special_tokens:
        # PAD and UNK (and any negative id).
        return DECODED_INVALID_ID
      elif encoded_id >= self._base_vocab_size:
        # Extra ids and anything beyond them.
        return DECODED_INVALID_ID
      else:
        return encoded_id - self._num_special_tokens

    ids = [_decode_id(int(i)) for i in ids]
    return ids

  def _encode_tf(self, token_ids: tf.Tensor) -> tf.Tensor:
    """Encode a tensor of token ids in TensorFlow.

    Args:
      token_ids: tf.Tensor of regular token ids, each expected in
        [0, regular_ids); range checks are attached as control dependencies.

    Returns:
      a tf.Tensor of the same shape and dtype as `token_ids`, with every id
      shifted up by the number of special tokens.
    """
    with tf.control_dependencies(
        [tf.debugging.assert_less(
            token_ids, tf.cast(self._num_regular_tokens, token_ids.dtype)),
         tf.debugging.assert_greater_equal(
             token_ids, tf.cast(0, token_ids.dtype))
        ]):
      tf_ids = token_ids + self._num_special_tokens
    return tf_ids

  def _decode_tf(self, ids: tf.Tensor) -> tf.Tensor:
    """Decode in TensorFlow.

    The special tokens PAD and UNK, as well as extra ids, are replaced with
    DECODED_INVALID_ID in the output. If EOS is present, it and all following
    tokens are replaced with DECODED_EOS_ID in the decoded output.

    Args:
      ids: tf.Tensor of encoded ids (presumably int32). The EOS mask is
        computed along the last axis.

    Returns:
      a tf.Tensor of decoded token ids with the same shape as `ids`.
    """
    # Create a mask that is true from the first EOS position onward.
    # First, create an array that is True whenever there is an EOS, then cumsum
    # that array so that every position after and including the first True is
    # >= 1, then cast back to bool for the final mask.
    eos_and_after = tf.cumsum(
        tf.cast(tf.equal(ids, self.eos_id), tf.int32), exclusive=False, axis=-1)
    eos_and_after = tf.cast(eos_and_after, tf.bool)
    return tf.where(
        eos_and_after,
        DECODED_EOS_ID,
        tf.where(
            tf.logical_and(
                tf.greater_equal(ids, self._num_special_tokens),
                tf.less(ids, self._base_vocab_size)),
            ids - self._num_special_tokens,
            DECODED_INVALID_ID))

  def __eq__(self, other):
    """Two vocabularies are equal when their regular and extra id counts match.

    Comparison with an object that is not a GenericTokenVocabulary returns
    NotImplemented, so Python falls back to its default handling (e.g.
    `vocab == None` is False) instead of raising AttributeError.
    """
    if not isinstance(other, GenericTokenVocabulary):
      return NotImplemented
    return (self.extra_ids == other.extra_ids and
            self._num_regular_tokens == other._num_regular_tokens)

  def __hash__(self):
    # Defining __eq__ alone sets __hash__ to None (instances unhashable); hash
    # on exactly the fields __eq__ compares so equal vocabularies hash equally.
    return hash((self.extra_ids, self._num_regular_tokens))
def num_embeddings(vocabulary: 'GenericTokenVocabulary',
                   multiple: int = 128) -> int:
  """Vocabulary size rounded up to a multiple of `multiple`.

  The default of 128 pads the embedding table for TPU efficiency.

  Args:
    vocabulary: an object exposing an integer `vocab_size` (normally a
      GenericTokenVocabulary).
    multiple: positive granularity to round up to; defaults to 128.

  Returns:
    The smallest multiple of `multiple` that is >= `vocabulary.vocab_size`.

  Raises:
    ValueError: if `multiple` is not positive.
  """
  if multiple <= 0:
    raise ValueError(f'multiple must be positive, got {multiple}')
  # Integer ceil-division; avoids going through float division.
  return multiple * ((vocabulary.vocab_size + multiple - 1) // multiple)