# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transcription task definitions."""
import functools
from typing import Optional, Sequence
from mt3 import datasets
from mt3 import event_codec
from mt3 import metrics
from mt3 import mixing
from mt3 import preprocessors
from mt3 import run_length_encoding
from mt3 import spectrograms
from mt3 import vocabularies
import note_seq
import numpy as np
import seqio
import t5
import tensorflow as tf
# Split audio frame sequences into this length before the cache placeholder.
# Splitting before caching bounds per-example size in the cached dataset;
# training-time chunking to the model input length happens after the cache.
MAX_NUM_CACHED_FRAMES = 2000

# Directory of precomputed task caches; seqio loads preprocessed examples from
# here (up to each task's CacheDatasetPlaceholder) instead of recomputing them.
seqio.add_global_cache_dirs(['gs://mt3/data/cache_tasks/'])
def construct_task_name(
    task_prefix: str,
    spectrogram_config=spectrograms.SpectrogramConfig(),
    vocab_config=vocabularies.VocabularyConfig(),
    task_suffix: Optional[str] = None
) -> str:
  """Build a task name joining prefix, config abbreviations, and suffix.

  Empty config abbreviation strings and a missing suffix are omitted, so the
  default configs typically contribute nothing to the name.
  """
  parts = [task_prefix]
  parts.extend(
      abbrev for abbrev in (spectrogram_config.abbrev_str,
                            vocab_config.abbrev_str)
      if abbrev)
  if task_suffix:
    parts.append(task_suffix)
  return '_'.join(parts)
def trim_eos(tokens: Sequence[int]) -> np.ndarray:
  """Truncate the token array at the first EOS, dropping EOS and what follows."""
  arr = np.array(tokens, np.int32)
  eos_positions = np.flatnonzero(arr == vocabularies.DECODED_EOS_ID)
  if eos_positions.size:
    # Keep only tokens strictly before the first decoded EOS.
    arr = arr[:eos_positions[0]]
  return arr
def postprocess(tokens, example, is_target, codec):
  """Transcription postprocessing function.

  Trims tokens at EOS, then builds either a reference dict (for targets) or a
  prediction dict (for model output) keyed by the example's unique id.
  """
  tokens = trim_eos(tokens)
  unique_id = example['unique_id'][0]

  if is_target:
    serialized_ns = example['sequence'][0]
    ref_ns = (note_seq.NoteSequence.FromString(serialized_ns)
              if serialized_ns else None)
    return {
        'unique_id': unique_id,
        'ref_ns': ref_ns,
        'ref_tokens': tokens,
    }

  # Round the segment start time down to the nearest symbolic token step so
  # predictions align with the codec's time grid.
  start_time = example['input_times'][0]
  start_time -= start_time % (1 / codec.steps_per_second)
  return {
      'unique_id': unique_id,
      'raw_inputs': example['raw_inputs'],
      'est_tokens': tokens,
      'start_time': start_time
  }
def add_transcription_task_to_registry(
    dataset_config: datasets.DatasetConfig,
    spectrogram_config: spectrograms.SpectrogramConfig,
    vocab_config: vocabularies.VocabularyConfig,
    tokenize_fn,  # TODO(iansimon): add type signature
    onsets_only: bool,
    include_ties: bool,
    skip_too_long: bool = False
) -> None:
  """Add note transcription task to seqio.TaskRegistry.

  Registers one training task (suffix 'train'), one inference-eval task per
  entry in `dataset_config.infer_eval_splits`, and a seqio Mixture (suffix
  'eval') over the eval tasks whose split sets `include_in_mixture`.

  Args:
    dataset_config: Dataset name, file paths, feature spec, and splits.
    spectrogram_config: Spectrogram parameters; its abbreviation also becomes
      part of the task name.
    vocab_config: Vocabulary parameters used to build the event codec; its
      abbreviation also becomes part of the task name.
    tokenize_fn: Preprocessor that converts raw examples into framewise inputs
      and symbolic event tokens.
    onsets_only: If True, name the task 'onsets' (onset-only transcription);
      otherwise 'notes'.
    include_ties: If True, append '_ties' to the task name and use a 'tie'
      token as the state-events end token during target extraction.
    skip_too_long: If True, drop over-length training examples instead of the
      default handling in `preprocessors.handle_too_long`.
  """
  codec = vocabularies.build_codec(vocab_config)
  vocabulary = vocabularies.vocabulary_from_codec(codec)

  # 'inputs' are rank-2 float spectrogram frames, not token ids.
  output_features = {
      'targets': seqio.Feature(vocabulary=vocabulary),
      'inputs': seqio.ContinuousFeature(dtype=tf.float32, rank=2)
  }

  task_name = 'onsets' if onsets_only else 'notes'
  if include_ties:
    task_name += '_ties'
  task_prefix = f'{dataset_config.name}_{task_name}'

  train_task_name = construct_task_name(
      task_prefix=task_prefix,
      spectrogram_config=spectrogram_config,
      vocab_config=vocab_config,
      task_suffix='train')

  # Eval tasks added below are collected here for the final 'eval' Mixture.
  mixture_task_names = []

  tie_token = codec.encode_event(event_codec.Event('tie', 0))
  track_specs = (dataset_config.track_specs
                 if dataset_config.track_specs else None)

  # Add transcription training task.
  seqio.TaskRegistry.add(
      train_task_name,
      source=seqio.TFExampleDataSource(
          split_to_filepattern={
              'train': dataset_config.paths[dataset_config.train_split],
              'eval': dataset_config.paths[dataset_config.train_eval_split]
          },
          feature_description=dataset_config.features),
      output_features=output_features,
      preprocessors=[
          # 1) Tokenize raw audio + annotations into frames and events.
          functools.partial(
              tokenize_fn,
              spectrogram_config=spectrogram_config, codec=codec,
              is_training_data=True, onsets_only=onsets_only,
              include_ties=include_ties),
          # 2) Split long frame sequences before caching so cached examples
          #    stay bounded (see MAX_NUM_CACHED_FRAMES); event index features
          #    are split in lockstep with the frames.
          functools.partial(
              t5.data.preprocessors.split_tokens,
              max_tokens_per_segment=MAX_NUM_CACHED_FRAMES,
              feature_key='inputs',
              additional_feature_keys=[
                  'input_event_start_indices', 'input_event_end_indices',
                  'input_state_event_indices'
              ],
              passthrough_feature_keys=['targets', 'state_events']),
          # Everything above may be served from the offline cache; everything
          # below runs at training time.
          seqio.CacheDatasetPlaceholder(),
          # 3) Per-epoch random chunking down to the model input length.
          functools.partial(
              t5.data.preprocessors.select_random_chunk,
              feature_key='inputs',
              additional_feature_keys=[
                  'input_event_start_indices', 'input_event_end_indices',
                  'input_state_event_indices'
              ],
              passthrough_feature_keys=['targets', 'state_events'],
              uniform_random_start=True),
          # 4) Slice out the event subsequence matching the chosen chunk,
          #    optionally prepending state events terminated by the tie token.
          functools.partial(
              run_length_encoding.extract_target_sequence_with_indices,
              state_events_end_token=tie_token if include_ties else None),
          # 5) Target-sequence transformations on the event tokens.
          functools.partial(preprocessors.map_midi_programs, codec=codec),
          run_length_encoding.run_length_encode_shifts_fn(
              codec,
              feature_key='targets'),
          functools.partial(
              mixing.mix_transcription_examples,
              codec=codec,
              targets_feature_keys=['targets']),
          run_length_encoding.remove_redundant_state_changes_fn(
              feature_key='targets', codec=codec,
              state_change_event_types=['velocity', 'program']),
          # 6) Convert raw frames to spectrograms and finalize targets.
          functools.partial(
              preprocessors.compute_spectrograms,
              spectrogram_config=spectrogram_config),
          functools.partial(preprocessors.handle_too_long, skip=skip_too_long),
          functools.partial(
              seqio.preprocessors.tokenize_and_append_eos,
              copy_pretokenized=False)
      ],
      postprocess_fn=None,
      metric_fns=[],
  )

  # Add transcription eval tasks, one per inference-eval split.
  for split in dataset_config.infer_eval_splits:
    eval_task_name = construct_task_name(
        task_prefix=task_prefix,
        spectrogram_config=spectrogram_config,
        vocab_config=vocab_config,
        task_suffix=split.suffix)

    if split.include_in_mixture:
      mixture_task_names.append(eval_task_name)

    seqio.TaskRegistry.add(
        eval_task_name,
        source=seqio.TFExampleDataSource(
            split_to_filepattern={'eval': dataset_config.paths[split.name]},
            feature_description=dataset_config.features),
        output_features=output_features,
        preprocessors=[
            functools.partial(
                tokenize_fn,
                spectrogram_config=spectrogram_config, codec=codec,
                is_training_data='train' in split.name, onsets_only=onsets_only,
                include_ties=include_ties),
            seqio.CacheDatasetPlaceholder(),
            preprocessors.add_unique_id,
            preprocessors.pad_notesequence_array,
            # Deterministic (non-random) chunking to input length; keep the
            # chunk start times and reference NoteSequence aligned with each
            # chunk, and carry the unique id through unchanged.
            functools.partial(
                t5.data.preprocessors.split_tokens_to_inputs_length,
                feature_key='inputs',
                additional_feature_keys=['input_times', 'sequence'],
                passthrough_feature_keys=['unique_id']),
            # Add dummy targets as they are dropped during the above split to
            # avoid memory blowups, but expected to be present by seqio; the
            # evaluation metrics currently only use the target NoteSequence.
            preprocessors.add_dummy_targets,
            functools.partial(
                preprocessors.compute_spectrograms,
                spectrogram_config=spectrogram_config),
            # Never skip over-length eval examples; evaluation must see all.
            functools.partial(preprocessors.handle_too_long, skip=False),
            functools.partial(
                seqio.preprocessors.tokenize_and_append_eos,
                copy_pretokenized=False)
        ],
        postprocess_fn=functools.partial(postprocess, codec=codec),
        metric_fns=[
            functools.partial(
                metrics.transcription_metrics,
                codec=codec,
                spectrogram_config=spectrogram_config,
                onsets_only=onsets_only,
                use_ties=include_ties,
                track_specs=track_specs)
        ],
    )

  # Equal-rate mixture over this dataset's included eval tasks.
  seqio.MixtureRegistry.add(
      construct_task_name(
          task_prefix=task_prefix, spectrogram_config=spectrogram_config,
          vocab_config=vocab_config, task_suffix='eval'),
      mixture_task_names,
      default_rate=1)
# Just use default spectrogram config.
SPECTROGRAM_CONFIG = spectrograms.SpectrogramConfig()

# Create two vocabulary configs, one default and one with only on-off velocity.
VOCAB_CONFIG_FULL = vocabularies.VocabularyConfig()
# num_velocity_bins=1 collapses velocity to note-on/off only.
VOCAB_CONFIG_NOVELOCITY = vocabularies.VocabularyConfig(num_velocity_bins=1)
# Register the per-dataset transcription tasks.  The full-velocity vocabulary
# is used only for the MAESTRO piano tasks; all multi-instrument tasks use the
# no-velocity vocabulary with tie events.

# Transcribe MAESTRO v1.
add_transcription_task_to_registry(
    dataset_config=datasets.MAESTROV1_CONFIG,
    spectrogram_config=SPECTROGRAM_CONFIG,
    vocab_config=VOCAB_CONFIG_FULL,
    tokenize_fn=functools.partial(
        preprocessors.tokenize_transcription_example,
        audio_is_samples=False,
        id_feature_key='id'),
    onsets_only=False,
    include_ties=False)

# Transcribe MAESTRO v3.
add_transcription_task_to_registry(
    dataset_config=datasets.MAESTROV3_CONFIG,
    spectrogram_config=SPECTROGRAM_CONFIG,
    vocab_config=VOCAB_CONFIG_FULL,
    tokenize_fn=functools.partial(
        preprocessors.tokenize_transcription_example,
        audio_is_samples=False,
        id_feature_key='id'),
    onsets_only=False,
    include_ties=False)

# Transcribe MAESTRO v3 without velocities, with ties.
add_transcription_task_to_registry(
    dataset_config=datasets.MAESTROV3_CONFIG,
    spectrogram_config=SPECTROGRAM_CONFIG,
    vocab_config=VOCAB_CONFIG_NOVELOCITY,
    tokenize_fn=functools.partial(
        preprocessors.tokenize_transcription_example,
        audio_is_samples=False,
        id_feature_key='id'),
    onsets_only=False,
    include_ties=True)

# Transcribe GuitarSet, with ties.
add_transcription_task_to_registry(
    dataset_config=datasets.GUITARSET_CONFIG,
    spectrogram_config=SPECTROGRAM_CONFIG,
    vocab_config=VOCAB_CONFIG_NOVELOCITY,
    tokenize_fn=preprocessors.tokenize_guitarset_example,
    onsets_only=False,
    include_ties=True)

# Transcribe URMP mixes, with ties.  Instrument names are mapped to MIDI
# programs via a URMP-specific lookup.
add_transcription_task_to_registry(
    dataset_config=datasets.URMP_CONFIG,
    spectrogram_config=SPECTROGRAM_CONFIG,
    vocab_config=VOCAB_CONFIG_NOVELOCITY,
    tokenize_fn=functools.partial(
        preprocessors.tokenize_example_with_program_lookup,
        inst_name_to_program_fn=preprocessors.urmp_instrument_to_program,
        id_feature_key='id'),
    onsets_only=False,
    include_ties=True)

# Transcribe MusicNet, with ties.  Unlike MAESTRO, MusicNet audio arrives as
# raw samples (audio_is_samples=True) rather than precomputed frames.
add_transcription_task_to_registry(
    dataset_config=datasets.MUSICNET_CONFIG,
    spectrogram_config=SPECTROGRAM_CONFIG,
    vocab_config=VOCAB_CONFIG_NOVELOCITY,
    tokenize_fn=functools.partial(
        preprocessors.tokenize_transcription_example,
        audio_is_samples=True,
        id_feature_key='id'),
    onsets_only=False,
    include_ties=True)

# Transcribe MusicNetEM, with ties.
add_transcription_task_to_registry(
    dataset_config=datasets.MUSICNET_EM_CONFIG,
    spectrogram_config=SPECTROGRAM_CONFIG,
    vocab_config=VOCAB_CONFIG_NOVELOCITY,
    tokenize_fn=functools.partial(
        preprocessors.tokenize_transcription_example,
        audio_is_samples=True,
        id_feature_key='id'),
    onsets_only=False,
    include_ties=True)

# Transcribe Cerberus4 (piano-guitar-bass-drums quartets), with ties.
add_transcription_task_to_registry(
    dataset_config=datasets.CERBERUS4_CONFIG,
    spectrogram_config=SPECTROGRAM_CONFIG,
    vocab_config=VOCAB_CONFIG_NOVELOCITY,
    tokenize_fn=functools.partial(
        preprocessors.tokenize_slakh_example,
        track_specs=datasets.CERBERUS4_CONFIG.track_specs,
        ignore_pitch_bends=True),
    onsets_only=False,
    include_ties=True)

# Transcribe 10 random sub-mixes of each song from Slakh, with ties.
add_transcription_task_to_registry(
    dataset_config=datasets.SLAKH_CONFIG,
    spectrogram_config=SPECTROGRAM_CONFIG,
    vocab_config=VOCAB_CONFIG_NOVELOCITY,
    tokenize_fn=functools.partial(
        preprocessors.tokenize_slakh_example,
        track_specs=None,
        ignore_pitch_bends=True),
    onsets_only=False,
    include_ties=True)
# Construct task names to include in transcription mixture.
MIXTURE_DATASET_NAMES = [
    'maestrov3', 'guitarset', 'urmp', 'musicnet_em', 'cerberus4', 'slakh'
]

# One '<dataset>_notes_ties_...' task name per dataset, for each split suffix.
MIXTURE_TRAIN_TASK_NAMES = [
    construct_task_name(task_prefix=f'{dataset_name}_notes_ties',
                        spectrogram_config=SPECTROGRAM_CONFIG,
                        vocab_config=VOCAB_CONFIG_NOVELOCITY,
                        task_suffix='train')
    for dataset_name in MIXTURE_DATASET_NAMES
]
MIXTURE_EVAL_TASK_NAMES = [
    construct_task_name(task_prefix=f'{dataset_name}_notes_ties',
                        spectrogram_config=SPECTROGRAM_CONFIG,
                        vocab_config=VOCAB_CONFIG_NOVELOCITY,
                        task_suffix='validation')
    for dataset_name in MIXTURE_DATASET_NAMES
]
MIXTURE_TEST_TASK_NAMES = []
# Temperature for example-count-based mixing across datasets; >1 flattens the
# sampling distribution relative to raw dataset sizes.
MIXING_TEMPERATURE = 10 / 3

# Add the mixture of all transcription tasks, with ties.
seqio.MixtureRegistry.add(
    construct_task_name(
        task_prefix='mega_notes_ties',
        spectrogram_config=SPECTROGRAM_CONFIG,
        vocab_config=VOCAB_CONFIG_NOVELOCITY,
        task_suffix='train'),
    MIXTURE_TRAIN_TASK_NAMES,
    default_rate=functools.partial(
        seqio.mixing_rate_num_examples,
        temperature=MIXING_TEMPERATURE))

# Matching eval mixture over each dataset's validation task.
seqio.MixtureRegistry.add(
    construct_task_name(
        task_prefix='mega_notes_ties',
        spectrogram_config=SPECTROGRAM_CONFIG,
        vocab_config=VOCAB_CONFIG_NOVELOCITY,
        task_suffix='eval'),
    MIXTURE_EVAL_TASK_NAMES,
    default_rate=functools.partial(
        seqio.mixing_rate_num_examples,
        temperature=MIXING_TEMPERATURE))