# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transcription preprocessors."""

from typing import Any, Callable, Mapping, Optional, Sequence, Tuple

from absl import logging
import gin
from immutabledict import immutabledict
import librosa

from mt3 import event_codec
from mt3 import note_sequences
from mt3 import run_length_encoding
from mt3 import spectrograms
from mt3 import vocabularies

import note_seq
import numpy as np
import seqio
import tensorflow as tf


def add_unique_id(ds: tf.data.Dataset) -> tf.data.Dataset:
  """Add unique integer ID to each example in a dataset."""

  def add_id_field(i, ex):
    ex['unique_id'] = [i]
    return ex

  return ds.enumerate().map(
      add_id_field, num_parallel_calls=tf.data.experimental.AUTOTUNE)


def pad_notesequence_array(ex):
  """Pad the NoteSequence array so that it can later be "split"."""
  ex['sequence'] = tf.pad(tf.expand_dims(ex['sequence'], 0),
                          [[0, len(ex['input_times']) - 1]])
  return ex


def add_dummy_targets(ex):
  """Add dummy targets; used in eval when targets are not actually used."""
  ex['targets'] = np.array([], dtype=np.int32)
  return ex


def _audio_to_frames(
    samples: Sequence[float],
    spectrogram_config: spectrograms.SpectrogramConfig,
) -> Tuple[Sequence[Sequence[int]], np.ndarray]:
  """Convert audio samples to non-overlapping frames and frame times."""
  frame_size = spectrogram_config.hop_width
  logging.info('Padding %d samples to multiple of %d', len(samples),
               frame_size)
  samples = np.pad(samples,
                   [0, frame_size - len(samples) % frame_size],
                   mode='constant')

  frames = spectrograms.split_audio(samples, spectrogram_config)

  num_frames = len(samples) // frame_size
  logging.info('Encoded %d samples to %d frames (%d samples each)',
               len(samples), num_frames, frame_size)

  times = np.arange(num_frames) / spectrogram_config.frames_per_second
  return frames, times
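

# Illustrative sketch (not part of the original module), assuming the default
# SpectrogramConfig values of sample_rate=16000 and hop_width=128, so that
# frames_per_second == 125:
#
#   config = spectrograms.SpectrogramConfig()
#   frames, times = _audio_to_frames(np.zeros(16000), config)
#   # 16000 samples are padded by one full frame to 16128 samples, giving
#   # 126 frames of 128 samples each; times[:3] == [0., 0.008, 0.016].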


def _include_inputs(ds, input_record, fields_to_omit=('audio',)):
  """Include fields from input record (other than audio) in dataset records."""
  def include_inputs_fn(output_record):
    for key in set(input_record.keys()) - set(output_record.keys()):
      output_record[key] = input_record[key]
    for key in fields_to_omit:
      del output_record[key]
    return output_record

  return ds.map(include_inputs_fn,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)


def tokenize_transcription_example(
    ds: tf.data.Dataset, spectrogram_config: spectrograms.SpectrogramConfig,
    codec: event_codec.Codec, is_training_data: bool,
    onsets_only: bool, include_ties: bool, audio_is_samples: bool,
    id_feature_key: Optional[str] = None
) -> tf.data.Dataset:
  """Tokenize a note transcription example for run-length encoding.

  Outputs include:
    inputs: audio sample frames, num_frames-by-frame_size
    input_times: timestamp for each frame
    targets: symbolic sequence of note-related events
    input_event_start_indices: start target index for every input index
    input_event_end_indices: end target index for every input index

  Args:
    ds: Input dataset.
    spectrogram_config: Spectrogram configuration.
    codec: Event vocabulary codec.
    is_training_data: Unused.
    onsets_only: If True, include only onset events (not offset, velocity, or
      program).
    include_ties: If True, also write state events containing active notes to
      support a "tie" section after run-length encoding.
    audio_is_samples: If True, audio is floating-point samples instead of
      serialized WAV.
    id_feature_key: If not None, replace sequence ID with specified key field
      from the dataset.

  Returns:
    Dataset with the outputs described above.
  """
  del is_training_data

  if onsets_only and include_ties:
    raise ValueError('Ties not supported when only modeling onsets.')

  def tokenize(sequence, audio, sample_rate, example_id=None):
    ns = note_seq.NoteSequence.FromString(sequence)
    note_sequences.validate_note_sequence(ns)

    if example_id is not None:
      ns.id = example_id

    if audio_is_samples:
      samples = audio
      if sample_rate != spectrogram_config.sample_rate:
        # Use keyword arguments so this works across librosa versions.
        samples = librosa.resample(
            samples, orig_sr=sample_rate,
            target_sr=spectrogram_config.sample_rate)
    else:
      samples = note_seq.audio_io.wav_data_to_samples_librosa(
          audio, sample_rate=spectrogram_config.sample_rate)

    logging.info('Got samples for %s::%s with length %d',
                 ns.id, ns.filename, len(samples))

    frames, frame_times = _audio_to_frames(samples, spectrogram_config)

    if onsets_only:
      times, values = note_sequences.note_sequence_to_onsets(ns)
    else:
      ns = note_seq.apply_sustain_control_changes(ns)
      times, values = (
          note_sequences.note_sequence_to_onsets_and_offsets_and_programs(ns))

    # The original NoteSequence can have a lot of control changes we don't
    # need; delete them.
    del ns.control_changes[:]

    (events, event_start_indices, event_end_indices,
     state_events, state_event_indices) = (
         run_length_encoding.encode_and_index_events(
             state=note_sequences.NoteEncodingState() if include_ties else None,
             event_times=times,
             event_values=values,
             encode_event_fn=note_sequences.note_event_data_to_events,
             codec=codec,
             frame_times=frame_times,
             encoding_state_to_events_fn=(
                 note_sequences.note_encoding_state_to_events
                 if include_ties else None)))

    yield {
        'inputs': frames,
        'input_times': frame_times,
        'targets': events,
        'input_event_start_indices': event_start_indices,
        'input_event_end_indices': event_end_indices,
        'state_events': state_events,
        'input_state_event_indices': state_event_indices,
        'sequence': ns.SerializeToString()
    }

  def process_record(input_record):
    if audio_is_samples and 'sample_rate' not in input_record:
      raise ValueError('Must provide sample rate when audio is samples.')

    args = [
        input_record['sequence'],
        input_record['audio'],
        input_record['sample_rate'] if 'sample_rate' in input_record else 0
    ]
    if id_feature_key is not None:
      args.append(input_record[id_feature_key])

    ds = tf.data.Dataset.from_generator(
        tokenize,
        output_signature={
            'inputs':
                tf.TensorSpec(
                    shape=(None, spectrogram_config.hop_width),
                    dtype=tf.float32),
            'input_times':
                tf.TensorSpec(shape=(None,), dtype=tf.float32),
            'targets':
                tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'input_event_start_indices':
                tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'input_event_end_indices':
                tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'state_events':
                tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'input_state_event_indices':
                tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'sequence':
                tf.TensorSpec(shape=(), dtype=tf.string)
        },
        args=args)

    ds = _include_inputs(ds, input_record)
    return ds

  tokenized_records = ds.flat_map(process_record)
  return tokenized_records
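

# Hedged usage sketch (wiring assumed, not defined in this module): in a
# seqio-style pipeline this preprocessor is typically bound to its
# configuration with functools.partial before being applied, e.g.:
#
#   import functools
#   tokenize_fn = functools.partial(
#       tokenize_transcription_example,
#       spectrogram_config=spectrograms.SpectrogramConfig(),
#       codec=codec,  # an event_codec.Codec built elsewhere
#       is_training_data=True, onsets_only=False, include_ties=True,
#       audio_is_samples=False)
#   ds = tokenize_fn(ds)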


def tokenize_guitarset_example(
    ds: tf.data.Dataset, spectrogram_config: spectrograms.SpectrogramConfig,
    codec: event_codec.Codec, is_training_data: bool,
    onsets_only: bool, include_ties: bool
) -> tf.data.Dataset:
  """Tokenize a GuitarSet transcription example."""
  def _preprocess_example(ex, name):
    assert 'inst_names' not in ex, 'Key `inst_names` is already populated.'
    ex['inst_names'] = [name]
    ex['instrument_sequences'] = [ex.pop('sequence')]
    return ex

  ds = ds.map(
      lambda x: _preprocess_example(x, 'Clean Guitar'),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  ds = tokenize_example_with_program_lookup(
      ds,
      spectrogram_config=spectrogram_config,
      codec=codec,
      is_training_data=is_training_data,
      inst_name_to_program_fn=guitarset_instrument_to_program,
      onsets_only=onsets_only,
      include_ties=include_ties,
      id_feature_key='id')
  return ds


def guitarset_instrument_to_program(instrument: str) -> int:
  """GuitarSet is all guitar, so return the first MIDI guitar program."""
  if instrument == 'Clean Guitar':
    return 24
  else:
    raise ValueError('Unknown GuitarSet instrument: %s' % instrument)


def tokenize_example_with_program_lookup(
    ds: tf.data.Dataset,
    spectrogram_config: spectrograms.SpectrogramConfig,
    codec: event_codec.Codec,
    is_training_data: bool,
    onsets_only: bool,
    include_ties: bool,
    inst_name_to_program_fn: Callable[[str], int],
    id_feature_key: Optional[str] = None
) -> tf.data.Dataset:
  """Tokenize an example, optionally looking up and assigning program numbers.

  This can be used by any dataset where a mapping function converts the
  `inst_names` feature to a set of program numbers.

  Args:
    ds: Input dataset.
    spectrogram_config: Spectrogram configuration.
    codec: Event vocabulary codec.
    is_training_data: Unused.
    onsets_only: If True, include only onset events (not offset & velocity).
    include_ties: If True, include tie events.
    inst_name_to_program_fn: A function used to map the instrument names
      in the `inst_names` feature of each example to a MIDI program number.
    id_feature_key: If not None, replace sequence ID with specified key field
      from the dataset.

  Returns:
    Dataset with the same output fields as `tokenize_transcription_example`.
  """
  del is_training_data

  def tokenize(sequences, inst_names, audio, example_id=None):
    # Add all the notes from the tracks to a single NoteSequence.
    ns = note_seq.NoteSequence(ticks_per_quarter=220)
    tracks = [note_seq.NoteSequence.FromString(seq) for seq in sequences]
    assert len(tracks) == len(inst_names)
    for track, inst_name in zip(tracks, inst_names):
      program = inst_name_to_program_fn(inst_name.decode())
      # Note that there are no pitch bends in URMP data; the below block will
      # raise PitchBendError if one is encountered.
      add_track_to_notesequence(ns, track, program=program, is_drum=False,
                                ignore_pitch_bends=False)

    note_sequences.assign_instruments(ns)
    note_sequences.validate_note_sequence(ns)

    if example_id is not None:
      ns.id = example_id

    samples = note_seq.audio_io.wav_data_to_samples_librosa(
        audio, sample_rate=spectrogram_config.sample_rate)

    logging.info('Got samples for %s::%s with length %d',
                 ns.id, ns.filename, len(samples))

    frames, frame_times = _audio_to_frames(samples, spectrogram_config)

    if onsets_only:
      times, values = note_sequences.note_sequence_to_onsets(ns)
    else:
      times, values = (
          note_sequences.note_sequence_to_onsets_and_offsets_and_programs(ns))

    # The original NoteSequence can have a lot of control changes we don't
    # need; delete them.
    del ns.control_changes[:]

    (events, event_start_indices, event_end_indices,
     state_events, state_event_indices) = (
         run_length_encoding.encode_and_index_events(
             state=note_sequences.NoteEncodingState() if include_ties else None,
             event_times=times,
             event_values=values,
             encode_event_fn=note_sequences.note_event_data_to_events,
             codec=codec,
             frame_times=frame_times,
             encoding_state_to_events_fn=(
                 note_sequences.note_encoding_state_to_events
                 if include_ties else None)))

    yield {
        'inputs': frames,
        'input_times': frame_times,
        'targets': events,
        'input_event_start_indices': event_start_indices,
        'input_event_end_indices': event_end_indices,
        'state_events': state_events,
        'input_state_event_indices': state_event_indices,
        'sequence': ns.SerializeToString()
    }

  def process_record(input_record):
    args = [
        input_record['instrument_sequences'],
        input_record['inst_names'],
        input_record['audio'],
    ]
    if id_feature_key is not None:
      args.append(input_record[id_feature_key])

    ds = tf.data.Dataset.from_generator(
        tokenize,
        output_signature={
            'inputs':
                tf.TensorSpec(
                    shape=(None, spectrogram_config.hop_width),
                    dtype=tf.float32),
            'input_times':
                tf.TensorSpec(shape=(None,), dtype=tf.float32),
            'targets':
                tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'input_event_start_indices':
                tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'input_event_end_indices':
                tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'state_events':
                tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'input_state_event_indices':
                tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'sequence':
                tf.TensorSpec(shape=(), dtype=tf.string)
        },
        args=args)

    ds = _include_inputs(ds, input_record)
    return ds

  tokenized_records = ds.flat_map(process_record)
  return tokenized_records


_URMP_INSTRUMENT_PROGRAMS = immutabledict({
    'vn': 40,   # violin
    'va': 41,   # viola
    'vc': 42,   # cello
    'db': 43,   # double bass
    'tpt': 56,  # trumpet
    'tbn': 57,  # trombone
    'tba': 58,  # tuba
    'hn': 60,   # French horn
    'sax': 64,  # saxophone
    'ob': 68,   # oboe
    'bn': 70,   # bassoon
    'cl': 71,   # clarinet
    'fl': 73    # flute
})


def urmp_instrument_to_program(urmp_instrument: str) -> int:
  """Fetch the program number associated with a given URMP instrument code."""
  if urmp_instrument not in _URMP_INSTRUMENT_PROGRAMS:
    raise ValueError('unknown URMP instrument: %s' % urmp_instrument)
  return _URMP_INSTRUMENT_PROGRAMS[urmp_instrument]


_SLAKH_CLASS_PROGRAMS = immutabledict({
    'Acoustic Piano': 0,
    'Electric Piano': 4,
    'Chromatic Percussion': 8,
    'Organ': 16,
    'Acoustic Guitar': 24,
    'Clean Electric Guitar': 26,
    'Distorted Electric Guitar': 29,
    'Acoustic Bass': 32,
    'Electric Bass': 33,
    'Violin': 40,
    'Viola': 41,
    'Cello': 42,
    'Contrabass': 43,
    'Orchestral Harp': 46,
    'Timpani': 47,
    'String Ensemble': 48,
    'Synth Strings': 50,
    'Choir and Voice': 52,
    'Orchestral Hit': 55,
    'Trumpet': 56,
    'Trombone': 57,
    'Tuba': 58,
    'French Horn': 60,
    'Brass Section': 61,
    'Soprano/Alto Sax': 64,
    'Tenor Sax': 66,
    'Baritone Sax': 67,
    'Oboe': 68,
    'English Horn': 69,
    'Bassoon': 70,
    'Clarinet': 71,
    'Pipe': 73,
    'Synth Lead': 80,
    'Synth Pad': 88
})


def slakh_class_to_program_and_is_drum(slakh_class: str) -> Tuple[int, bool]:
  """Map Slakh class string to program number and boolean indicating drums."""
  if slakh_class == 'Drums':
    return 0, True
  elif slakh_class not in _SLAKH_CLASS_PROGRAMS:
    raise ValueError('unknown Slakh class: %s' % slakh_class)
  else:
    return _SLAKH_CLASS_PROGRAMS[slakh_class], False
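

# For example (values taken from _SLAKH_CLASS_PROGRAMS above):
#   slakh_class_to_program_and_is_drum('Distorted Electric Guitar') == (29, False)
#   slakh_class_to_program_and_is_drum('Drums') == (0, True)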


class PitchBendError(Exception):
  pass


def add_track_to_notesequence(ns: note_seq.NoteSequence,
                              track: note_seq.NoteSequence,
                              program: int, is_drum: bool,
                              ignore_pitch_bends: bool):
  """Add a track to a NoteSequence."""
  if track.pitch_bends and not ignore_pitch_bends:
    raise PitchBendError
  track_sus = note_seq.apply_sustain_control_changes(track)
  for note in track_sus.notes:
    note.program = program
    note.is_drum = is_drum
    ns.notes.extend([note])
    ns.total_time = max(ns.total_time, note.end_time)


def tokenize_slakh_example(
    ds: tf.data.Dataset,
    spectrogram_config: spectrograms.SpectrogramConfig,
    codec: event_codec.Codec,
    is_training_data: bool,
    onsets_only: bool,
    include_ties: bool,
    track_specs: Optional[Sequence[note_sequences.TrackSpec]],
    ignore_pitch_bends: bool
) -> tf.data.Dataset:
  """Tokenize a Slakh multitrack note transcription example."""
  def tokenize(sequences, samples, sample_rate, inst_names, example_id):
    if sample_rate != spectrogram_config.sample_rate:
      # Use keyword arguments so this works across librosa versions.
      samples = librosa.resample(
          samples, orig_sr=sample_rate,
          target_sr=spectrogram_config.sample_rate)

    frames, frame_times = _audio_to_frames(samples, spectrogram_config)

    # Add all the notes from the tracks to a single NoteSequence.
    ns = note_seq.NoteSequence(ticks_per_quarter=220)
    tracks = [note_seq.NoteSequence.FromString(seq) for seq in sequences]
    assert len(tracks) == len(inst_names)
    if track_specs:
      # Specific tracks expected.
      assert len(tracks) == len(track_specs)
      for track, spec, inst_name in zip(tracks, track_specs, inst_names):
        # Make sure the instrument name matches what we expect.
        assert inst_name.decode() == spec.name
        try:
          add_track_to_notesequence(ns, track,
                                    program=spec.program, is_drum=spec.is_drum,
                                    ignore_pitch_bends=ignore_pitch_bends)
        except PitchBendError:
          # TODO(iansimon): is there a way to count these?
          return
    else:
      for track, inst_name in zip(tracks, inst_names):
        # Instrument name should be Slakh class.
        program, is_drum = slakh_class_to_program_and_is_drum(
            inst_name.decode())
        try:
          add_track_to_notesequence(ns, track, program=program,
                                    is_drum=is_drum,
                                    ignore_pitch_bends=ignore_pitch_bends)
        except PitchBendError:
          # TODO(iansimon): is there a way to count these?
          return

    note_sequences.assign_instruments(ns)
    note_sequences.validate_note_sequence(ns)

    if is_training_data:
      # Trim overlapping notes in training (as our event vocabulary cannot
      # represent them), but preserve original NoteSequence for eval.
      ns = note_sequences.trim_overlapping_notes(ns)

    ns.id = example_id

    if onsets_only:
      times, values = note_sequences.note_sequence_to_onsets(ns)
    else:
      times, values = (
          note_sequences.note_sequence_to_onsets_and_offsets_and_programs(ns))

    (events, event_start_indices, event_end_indices,
     state_events, state_event_indices) = (
         run_length_encoding.encode_and_index_events(
             state=note_sequences.NoteEncodingState() if include_ties else None,
             event_times=times,
             event_values=values,
             encode_event_fn=note_sequences.note_event_data_to_events,
             codec=codec,
             frame_times=frame_times,
             encoding_state_to_events_fn=(
                 note_sequences.note_encoding_state_to_events
                 if include_ties else None)))

    yield {
        'inputs': frames,
        'input_times': frame_times,
        'targets': events,
        'input_event_start_indices': event_start_indices,
        'input_event_end_indices': event_end_indices,
        'state_events': state_events,
        'input_state_event_indices': state_event_indices,
        'sequence': ns.SerializeToString()
    }

  def process_record(input_record):
    ds = tf.data.Dataset.from_generator(
        tokenize,
        output_signature={
            'inputs':
                tf.TensorSpec(
                    shape=(None, spectrogram_config.hop_width),
                    dtype=tf.float32),
            'input_times':
                tf.TensorSpec(shape=(None,), dtype=tf.float32),
            'targets':
                tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'input_event_start_indices':
                tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'input_event_end_indices':
                tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'state_events':
                tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'input_state_event_indices':
                tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'sequence':
                tf.TensorSpec(shape=(), dtype=tf.string)
        },
        args=[
            input_record['note_sequences'], input_record['mix'],
            input_record['audio_sample_rate'], input_record['inst_names'],
            input_record['track_id']
        ])

    ds = _include_inputs(ds, input_record, fields_to_omit=['mix', 'stems'])
    return ds

  tokenized_records = ds.flat_map(process_record)
  return tokenized_records


def compute_spectrograms(ex, spectrogram_config):
  """Flatten sample frames, compute the spectrogram, and keep raw samples."""
  samples = spectrograms.flatten_frames(ex['inputs'])
  ex['inputs'] = spectrograms.compute_spectrogram(samples, spectrogram_config)
  ex['raw_inputs'] = samples
  return ex


def handle_too_long(dataset: tf.data.Dataset,
                    output_features: seqio.preprocessors.OutputFeaturesType,
                    sequence_length: seqio.preprocessors.SequenceLengthType,
                    skip: bool = False) -> tf.data.Dataset:
  """Handle sequences that are too long, by either failing or skipping them."""
  def max_length_for_key(key):
    max_length = sequence_length[key]
    if output_features[key].add_eos:
      max_length -= 1
    return max_length

  if skip:
    # Drop examples where one of the features is longer than its maximum
    # sequence length.
    def is_not_too_long(ex):
      return not tf.reduce_any(
          [k in output_features and len(v) > max_length_for_key(k)
           for k, v in ex.items()])
    dataset = dataset.filter(is_not_too_long)

  def assert_not_too_long(key: str, value: tf.Tensor) -> tf.Tensor:
    if key in output_features:
      max_length = max_length_for_key(key)
      tf.debugging.assert_less_equal(
          tf.shape(value)[0], max_length,
          f'Value for "{key}" field exceeds maximum length')
    return value

  # Assert that no examples have features longer than their maximum sequence
  # length.
  return dataset.map(
      lambda ex: {k: assert_not_too_long(k, v) for k, v in ex.items()},
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
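

# Hedged usage sketch (the sequence lengths here are illustrative, not taken
# from this module):
#
#   ds = handle_too_long(
#       ds, output_features=output_features,
#       sequence_length={'inputs': 512, 'targets': 1024}, skip=True)
#
# With skip=True, over-long examples are dropped from the dataset; with the
# default skip=False, the resulting dataset raises an assertion error at
# iteration time if any feature exceeds its maximum length.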


def map_midi_programs(
    ds: tf.data.Dataset,
    codec: event_codec.Codec,
    granularity_type: str = 'full',
    feature_key: str = 'targets'
) -> tf.data.Dataset:
  """Apply MIDI program map to token sequences."""
  granularity = vocabularies.PROGRAM_GRANULARITIES[granularity_type]

  def _map_program_tokens(ex):
    ex[feature_key] = granularity.tokens_map_fn(ex[feature_key], codec)
    return ex

  return ds.map(_map_program_tokens,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)
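

# Hedged example (granularity names assumed to be keys of
# vocabularies.PROGRAM_GRANULARITIES): granularity_type='full' leaves program
# tokens unchanged, while a coarser granularity such as 'midi_class' or 'flat'
# collapses them before training, e.g.:
#
#   ds = map_midi_programs(ds, codec, granularity_type='midi_class')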