import os

import gradio as gr

# Runtime setup for the hosted demo: install gsutil so model assets can be
# fetched from Google Cloud Storage.
os.system("pip install gsutil")

# Clone T5X into the working directory and install it, swapping the TPU
# build of jax for the plain one in its setup.py.
os.system("git clone --branch=main https://github.com/google-research/t5x")
os.system("mv t5x t5x_tmp; mv t5x_tmp/* .; rm -r t5x_tmp")
os.system(r"sed -i 's:jax\[tpu\]:jax:' setup.py")
os.system("python3 -m pip install -e .")

# Clone MT3 into the working directory and install it.
os.system("git clone --branch=main https://github.com/magenta/mt3")
os.system("mv mt3 mt3_tmp; mv mt3_tmp/* .; rm -r mt3_tmp")
os.system("python3 -m pip install -e .")

# Fetch the pretrained checkpoints and a SoundFont (the SoundFont is only
# needed if you want to synthesize the transcription back to audio).
os.system("gsutil -q -m cp -r gs://mt3/checkpoints .")
os.system("gsutil -q -m cp gs://magentadata/soundfonts/SGM-v2.01-Sal-Guit-Bass-V1.3.sf2 .")

import functools

import numpy as np
import tensorflow.compat.v2 as tf

import gin
import jax
import librosa
import nest_asyncio
import note_seq
import seqio
import t5
import t5x

from mt3 import metrics_utils
from mt3 import models
from mt3 import network
from mt3 import note_sequences
from mt3 import preprocessors
from mt3 import spectrograms
from mt3 import vocabularies

# Allow nested asyncio event loops (t5x checkpoint loading runs its own loop
# inside the app's loop).
nest_asyncio.apply()

SAMPLE_RATE = 16000
SF2_PATH = 'SGM-v2.01-Sal-Guit-Bass-V1.3.sf2'


def upload_audio(audio, sample_rate):
  """Decode raw audio file bytes into a 1-d sample array at `sample_rate`."""
  return note_seq.audio_io.wav_data_to_samples_librosa(
      audio, sample_rate=sample_rate)
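
# Example use of upload_audio (hypothetical file name; any format librosa can
# decode should work, and the result is resampled to 16 kHz):
#
#   with open('example.wav', 'rb') as f:
#     samples = upload_audio(f.read(), sample_rate=SAMPLE_RATE)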


class InferenceModel(object):
  """Wrapper of T5X model for music transcription."""

  def __init__(self, checkpoint_path, model_type='mt3'):

    # 'ismir2021' is the piano-only model with note velocities; 'mt3' is the
    # multi-task multitrack model.
    if model_type == 'ismir2021':
      num_velocity_bins = 127
      self.encoding_spec = note_sequences.NoteEncodingSpec
      self.inputs_length = 512
    elif model_type == 'mt3':
      num_velocity_bins = 1
      self.encoding_spec = note_sequences.NoteEncodingWithTiesSpec
      self.inputs_length = 256
    else:
      raise ValueError('unknown model_type: %s' % model_type)

    gin_files = ['/home/user/app/mt3/gin/model.gin',
                 f'/home/user/app/mt3/gin/{model_type}.gin']

    self.batch_size = 8
    self.outputs_length = 1024
    self.sequence_length = {'inputs': self.inputs_length,
                            'targets': self.outputs_length}

    self.partitioner = t5x.partitioning.ModelBasedPjitPartitioner(
        model_parallel_submesh=(1, 1, 1, 1), num_partitions=1)

    self.spectrogram_config = spectrograms.SpectrogramConfig()
    self.codec = vocabularies.build_codec(
        vocab_config=vocabularies.VocabularyConfig(
            num_velocity_bins=num_velocity_bins))
    self.vocabulary = vocabularies.vocabulary_from_codec(self.codec)
    self.output_features = {
        'inputs': seqio.ContinuousFeature(dtype=tf.float32, rank=2),
        'targets': seqio.Feature(vocabulary=self.vocabulary),
    }

    self._parse_gin(gin_files)
    self.model = self._load_model()

    self.restore_from_checkpoint(checkpoint_path)

  @property
  def input_shapes(self):
    return {
        'encoder_input_tokens': (self.batch_size, self.inputs_length),
        'decoder_input_tokens': (self.batch_size, self.outputs_length)
    }

  def _parse_gin(self, gin_files):
    """Parse gin files used to train the model."""
    gin_bindings = [
        'from __gin__ import dynamic_registration',
        'from mt3 import vocabularies',
        'VOCAB_CONFIG=@vocabularies.VocabularyConfig()',
        'vocabularies.VocabularyConfig.num_velocity_bins=%NUM_VELOCITY_BINS'
    ]
    with gin.unlock_config():
      gin.parse_config_files_and_bindings(
          gin_files, gin_bindings, finalize_config=False)

  def _load_model(self):
    """Load up a T5X `Model` after parsing training gin config."""
    model_config = gin.get_configurable(network.T5Config)()
    module = network.Transformer(config=model_config)
    return models.ContinuousInputsEncoderDecoderModel(
        module=module,
        input_vocabulary=self.output_features['inputs'].vocabulary,
        output_vocabulary=self.output_features['targets'].vocabulary,
        optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0),
        input_depth=spectrograms.input_depth(self.spectrogram_config))

  def restore_from_checkpoint(self, checkpoint_path):
    """Restore training state from checkpoint, resets self._predict_fn()."""
    train_state_initializer = t5x.utils.TrainStateInitializer(
        optimizer_def=self.model.optimizer_def,
        init_fn=self.model.get_initial_variables,
        input_shapes=self.input_shapes,
        partitioner=self.partitioner)

    restore_checkpoint_cfg = t5x.utils.RestoreCheckpointConfig(
        path=checkpoint_path, mode='specific', dtype='float32')

    train_state_axes = train_state_initializer.train_state_axes
    self._predict_fn = self._get_predict_fn(train_state_axes)
    self._train_state = train_state_initializer.from_checkpoint_or_scratch(
        [restore_checkpoint_cfg], init_rng=jax.random.PRNGKey(0))

  @functools.lru_cache()
  def _get_predict_fn(self, train_state_axes):
    """Generate a partitioned prediction function for decoding."""
    # lru_cache keeps the partitioned (and hence jit-compiled) function
    # around across calls instead of rebuilding it for every prediction.
    def partial_predict_fn(params, batch, decode_rng):
      # decode_rng is ignored so that decoding is deterministic.
      return self.model.predict_batch_with_aux(
          params, batch, decoder_params={'decode_rng': None})
    return self.partitioner.partition(
        partial_predict_fn,
        in_axis_resources=(
            train_state_axes.params,
            t5x.partitioning.PartitionSpec('data',), None),
        out_axis_resources=t5x.partitioning.PartitionSpec('data',))

  def predict_tokens(self, batch, seed=0):
    """Predict tokens from preprocessed dataset batch."""
    prediction, _ = self._predict_fn(
        self._train_state.params, batch, jax.random.PRNGKey(seed))
    return self.vocabulary.decode_tf(prediction).numpy()

  def __call__(self, audio):
    """Infer note sequence from audio samples.

    Args:
      audio: 1-d numpy array of audio samples (16kHz) for a single example.

    Returns:
      A NoteSequence of the transcribed audio.
    """
    ds = self.audio_to_dataset(audio)
    ds = self.preprocess(ds)

    model_ds = self.model.FEATURE_CONVERTER_CLS(pack=False)(
        ds, task_feature_lengths=self.sequence_length)
    model_ds = model_ds.batch(self.batch_size)

    inferences = (tokens for batch in model_ds.as_numpy_iterator()
                  for tokens in self.predict_tokens(batch))

    predictions = []
    for example, tokens in zip(ds.as_numpy_iterator(), inferences):
      predictions.append(self.postprocess(tokens, example))

    result = metrics_utils.event_predictions_to_ns(
        predictions, codec=self.codec, encoding_spec=self.encoding_spec)
    return result['est_ns']

  def audio_to_dataset(self, audio):
    """Create a TF Dataset of spectrograms from input audio."""
    frames, frame_times = self._audio_to_frames(audio)
    return tf.data.Dataset.from_tensors({
        'inputs': frames,
        'input_times': frame_times,
    })

  def _audio_to_frames(self, audio):
    """Compute spectrogram frames from audio."""
    frame_size = self.spectrogram_config.hop_width
    # Pad so the sample count is an exact multiple of the frame size
    # (zero padding when it already is).
    padding = [0, -len(audio) % frame_size]
    audio = np.pad(audio, padding, mode='constant')
    frames = spectrograms.split_audio(audio, self.spectrogram_config)
    num_frames = len(audio) // frame_size
    # Start time of each frame in seconds, at frames_per_second =
    # sample_rate / hop_width.
    times = np.arange(num_frames) / self.spectrogram_config.frames_per_second
    return frames, times

  def preprocess(self, ds):
    pp_chain = [
        functools.partial(
            t5.data.preprocessors.split_tokens_to_inputs_length,
            sequence_length=self.sequence_length,
            output_features=self.output_features,
            feature_key='inputs',
            additional_feature_keys=['input_times']),
        preprocessors.add_dummy_targets,
        functools.partial(
            preprocessors.compute_spectrograms,
            spectrogram_config=self.spectrogram_config)
    ]
    for pp in pp_chain:
      ds = pp(ds)
    return ds

  def postprocess(self, tokens, example):
    tokens = self._trim_eos(tokens)
    start_time = example['input_times'][0]
    # Round down to the nearest symbolic token step.
    start_time -= start_time % (1 / self.codec.steps_per_second)
    return {
        'est_tokens': tokens,
        'start_time': start_time,
        # Internal MT3 code expects raw inputs, not used here.
        'raw_inputs': []
    }

  @staticmethod
  def _trim_eos(tokens):
    """Drop everything from the first decoded EOS token onward."""
    tokens = np.array(tokens, np.int32)
    if vocabularies.DECODED_EOS_ID in tokens:
      tokens = tokens[:np.argmax(tokens == vocabularies.DECODED_EOS_ID)]
    return tokens


inference_model = InferenceModel('/home/user/app/checkpoints/mt3/', 'mt3')
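
# The checkpoints directory fetched above also contains an 'ismir2021'
# piano-only checkpoint; presumably it could be loaded instead with:
#
#   inference_model = InferenceModel('/home/user/app/checkpoints/ismir2021/',
#                                    'ismir2021')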


def inference(audio):
  with open(audio, 'rb') as fd:
    contents = fd.read()
  audio = upload_audio(contents, sample_rate=SAMPLE_RATE)

  est_ns = inference_model(audio)

  note_seq.sequence_proto_to_midi_file(est_ns, './transcribed.mid')

  return './transcribed.mid'
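
# The downloaded SoundFont is not used by the demo itself; a sketch of how a
# transcription could be rendered back to audio with it (assumes pyfluidsynth
# is installed, which this app does not do):
#
#   audio_out = note_seq.fluidsynth(
#       est_ns, sample_rate=SAMPLE_RATE, sf2_path=SF2_PATH)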

title = "MT3"
description = "Gradio demo for MT3: Multi-Task Multitrack Music Transcription. To use it, simply upload your audio file, or click one of the examples to load it. Read more at the links below."

article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.03017' target='_blank'>MT3: Multi-Task Multitrack Music Transcription</a> | <a href='https://github.com/magenta/mt3' target='_blank'>GitHub Repo</a></p>"

examples = [['download.wav']]

gr.Interface(
    inference,
    gr.inputs.Audio(type="filepath", label="Input"),
    [gr.outputs.File(label="Output")],
    title=title,
    description=description,
    article=article,
    examples=examples
).launch(enable_queue=True, cache_examples=True)