juancopi81 committed on
Commit
13af1af
1 Parent(s): 0f3083a

Change organization of code

Browse files
Files changed (3) hide show
  1. app.py +3 -227
  2. inferencemodel.py +222 -0
  3. requirements.txt +1 -2
app.py CHANGED
@@ -1,241 +1,20 @@
1
- import os
2
  import gradio as gr
3
- from pathlib import Path
4
 
5
- os.system("python3 -m pip install -e .")
6
-
7
- import functools
8
- import os
9
-
10
- import numpy as np
11
- import tensorflow.compat.v2 as tf
12
- from pydub import AudioSegment
13
-
14
- import functools
15
- import gin
16
- import jax
17
- import librosa
18
  import note_seq
19
- import seqio
20
- import t5
21
- import t5x
22
-
23
- from mt3 import metrics_utils
24
- from mt3 import models
25
- from mt3 import network
26
- from mt3 import note_sequences
27
- from mt3 import preprocessors
28
- from mt3 import spectrograms
29
- from mt3 import vocabularies
30
-
31
 
32
  import nest_asyncio
33
  nest_asyncio.apply()
34
 
 
 
35
  SAMPLE_RATE = 16000
36
  SF2_PATH = 'SGM-v2.01-Sal-Guit-Bass-V1.3.sf2'
37
 
38
  def upload_audio(audio, sample_rate):
 
39
  return note_seq.audio_io.wav_data_to_samples_librosa(
40
  audio, sample_rate=sample_rate)
41
 
42
-
43
-
44
- class InferenceModel(object):
45
- """Wrapper of T5X model for music transcription."""
46
-
47
- def __init__(self, checkpoint_path, model_type='mt3'):
48
-
49
- # Model Constants.
50
- if model_type == 'ismir2021':
51
- num_velocity_bins = 127
52
- self.encoding_spec = note_sequences.NoteEncodingSpec
53
- self.inputs_length = 512
54
- elif model_type == 'mt3':
55
- num_velocity_bins = 1
56
- self.encoding_spec = note_sequences.NoteEncodingWithTiesSpec
57
- self.inputs_length = 256
58
- else:
59
- raise ValueError('unknown model_type: %s' % model_type)
60
-
61
- gin_files = ['/home/user/app/mt3/gin/model.gin',
62
- '/home/user/app/mt3/gin/mt3.gin']
63
-
64
- self.batch_size = 8
65
- self.outputs_length = 1024
66
- self.sequence_length = {'inputs': self.inputs_length,
67
- 'targets': self.outputs_length}
68
-
69
- self.partitioner = t5x.partitioning.PjitPartitioner(
70
- model_parallel_submesh=(1, 1, 1, 1), num_partitions=1)
71
-
72
- # Build Codecs and Vocabularies.
73
- self.spectrogram_config = spectrograms.SpectrogramConfig()
74
- self.codec = vocabularies.build_codec(
75
- vocab_config=vocabularies.VocabularyConfig(
76
- num_velocity_bins=num_velocity_bins))
77
- self.vocabulary = vocabularies.vocabulary_from_codec(self.codec)
78
- self.output_features = {
79
- 'inputs': seqio.ContinuousFeature(dtype=tf.float32, rank=2),
80
- 'targets': seqio.Feature(vocabulary=self.vocabulary),
81
- }
82
-
83
- # Create a T5X model.
84
- self._parse_gin(gin_files)
85
- self.model = self._load_model()
86
-
87
- # Restore from checkpoint.
88
- self.restore_from_checkpoint(checkpoint_path)
89
-
90
- @property
91
- def input_shapes(self):
92
- return {
93
- 'encoder_input_tokens': (self.batch_size, self.inputs_length),
94
- 'decoder_input_tokens': (self.batch_size, self.outputs_length)
95
- }
96
-
97
- def _parse_gin(self, gin_files):
98
- """Parse gin files used to train the model."""
99
- gin_bindings = [
100
- 'from __gin__ import dynamic_registration',
101
- 'from mt3 import vocabularies',
102
- 'VOCAB_CONFIG=@vocabularies.VocabularyConfig()',
103
- 'vocabularies.VocabularyConfig.num_velocity_bins=%NUM_VELOCITY_BINS'
104
- ]
105
- with gin.unlock_config():
106
- gin.parse_config_files_and_bindings(
107
- gin_files, gin_bindings, finalize_config=False)
108
-
109
- def _load_model(self):
110
- """Load up a T5X `Model` after parsing training gin config."""
111
- model_config = gin.get_configurable(network.T5Config)()
112
- module = network.Transformer(config=model_config)
113
- return models.ContinuousInputsEncoderDecoderModel(
114
- module=module,
115
- input_vocabulary=self.output_features['inputs'].vocabulary,
116
- output_vocabulary=self.output_features['targets'].vocabulary,
117
- optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0),
118
- input_depth=spectrograms.input_depth(self.spectrogram_config))
119
-
120
-
121
- def restore_from_checkpoint(self, checkpoint_path):
122
- """Restore training state from checkpoint, resets self._predict_fn()."""
123
- train_state_initializer = t5x.utils.TrainStateInitializer(
124
- optimizer_def=self.model.optimizer_def,
125
- init_fn=self.model.get_initial_variables,
126
- input_shapes=self.input_shapes,
127
- partitioner=self.partitioner)
128
-
129
- restore_checkpoint_cfg = t5x.utils.RestoreCheckpointConfig(
130
- path=checkpoint_path, mode='specific', dtype='float32')
131
-
132
- train_state_axes = train_state_initializer.train_state_axes
133
- self._predict_fn = self._get_predict_fn(train_state_axes)
134
- self._train_state = train_state_initializer.from_checkpoint_or_scratch(
135
- [restore_checkpoint_cfg], init_rng=jax.random.PRNGKey(0))
136
-
137
- @functools.lru_cache()
138
- def _get_predict_fn(self, train_state_axes):
139
- """Generate a partitioned prediction function for decoding."""
140
- def partial_predict_fn(params, batch, decode_rng):
141
- return self.model.predict_batch_with_aux(
142
- params, batch, decoder_params={'decode_rng': None})
143
- return self.partitioner.partition(
144
- partial_predict_fn,
145
- in_axis_resources=(
146
- train_state_axes.params,
147
- t5x.partitioning.PartitionSpec('data',), None),
148
- out_axis_resources=t5x.partitioning.PartitionSpec('data',)
149
- )
150
-
151
- def predict_tokens(self, batch, seed=0):
152
- """Predict tokens from preprocessed dataset batch."""
153
- prediction, _ = self._predict_fn(
154
- self._train_state.params, batch, jax.random.PRNGKey(seed))
155
- return self.vocabulary.decode_tf(prediction).numpy()
156
-
157
- def __call__(self, audio):
158
- """Infer note sequence from audio samples.
159
-
160
- Args:
161
- audio: 1-d numpy array of audio samples (16kHz) for a single example.
162
- Returns:
163
- A note_sequence of the transcribed audio.
164
- """
165
- ds = self.audio_to_dataset(audio)
166
- ds = self.preprocess(ds)
167
-
168
- model_ds = self.model.FEATURE_CONVERTER_CLS(pack=False)(
169
- ds, task_feature_lengths=self.sequence_length)
170
- model_ds = model_ds.batch(self.batch_size)
171
-
172
- inferences = (tokens for batch in model_ds.as_numpy_iterator()
173
- for tokens in self.predict_tokens(batch))
174
-
175
- predictions = []
176
- for example, tokens in zip(ds.as_numpy_iterator(), inferences):
177
- predictions.append(self.postprocess(tokens, example))
178
-
179
- result = metrics_utils.event_predictions_to_ns(
180
- predictions, codec=self.codec, encoding_spec=self.encoding_spec)
181
- return result['est_ns']
182
-
183
- def audio_to_dataset(self, audio):
184
- """Create a TF Dataset of spectrograms from input audio."""
185
- frames, frame_times = self._audio_to_frames(audio)
186
- return tf.data.Dataset.from_tensors({
187
- 'inputs': frames,
188
- 'input_times': frame_times,
189
- })
190
-
191
- def _audio_to_frames(self, audio):
192
- """Compute spectrogram frames from audio."""
193
- frame_size = self.spectrogram_config.hop_width
194
- padding = [0, frame_size - len(audio) % frame_size]
195
- audio = np.pad(audio, padding, mode='constant')
196
- frames = spectrograms.split_audio(audio, self.spectrogram_config)
197
- num_frames = len(audio) // frame_size
198
- times = np.arange(num_frames) / self.spectrogram_config.frames_per_second
199
- return frames, times
200
-
201
- def preprocess(self, ds):
202
- pp_chain = [
203
- functools.partial(
204
- t5.data.preprocessors.split_tokens_to_inputs_length,
205
- sequence_length=self.sequence_length,
206
- output_features=self.output_features,
207
- feature_key='inputs',
208
- additional_feature_keys=['input_times']),
209
- # Cache occurs here during training.
210
- preprocessors.add_dummy_targets,
211
- functools.partial(
212
- preprocessors.compute_spectrograms,
213
- spectrogram_config=self.spectrogram_config)
214
- ]
215
- for pp in pp_chain:
216
- ds = pp(ds)
217
- return ds
218
-
219
- def postprocess(self, tokens, example):
220
- tokens = self._trim_eos(tokens)
221
- start_time = example['input_times'][0]
222
- # Round down to nearest symbolic token step.
223
- start_time -= start_time % (1 / self.codec.steps_per_second)
224
- return {
225
- 'est_tokens': tokens,
226
- 'start_time': start_time,
227
- # Internal MT3 code expects raw inputs, not used here.
228
- 'raw_inputs': []
229
- }
230
-
231
- @staticmethod
232
- def _trim_eos(tokens):
233
- tokens = np.array(tokens, np.int32)
234
- if vocabularies.DECODED_EOS_ID in tokens:
235
- tokens = tokens[:np.argmax(tokens == vocabularies.DECODED_EOS_ID)]
236
- return tokens
237
-
238
-
239
  # Start inference model
240
  inference_model = InferenceModel('/home/user/app/checkpoints/mt3/', 'mt3')
241
 
@@ -267,7 +46,4 @@ gr.Interface(
267
  description=description,
268
  article=article,
269
  examples=examples,
270
- allow_flagging=False,
271
- allow_screenshot=False,
272
- enable_queue=True
273
  ).launch()
 
 
1
  import gradio as gr
 
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import note_seq
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  import nest_asyncio
6
  nest_asyncio.apply()
7
 
8
+ from inferencemodel import InferenceModel
9
+
10
  SAMPLE_RATE = 16000
11
  SF2_PATH = 'SGM-v2.01-Sal-Guit-Bass-V1.3.sf2'
12
 
13
  def upload_audio(audio, sample_rate):
14
+
15
  return note_seq.audio_io.wav_data_to_samples_librosa(
16
  audio, sample_rate=sample_rate)
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  # Start inference model
19
  inference_model = InferenceModel('/home/user/app/checkpoints/mt3/', 'mt3')
20
 
 
46
  description=description,
47
  article=article,
48
  examples=examples,
 
 
 
49
  ).launch()
inferencemodel.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ os.system("python3 -m pip install -e .")
4
+
5
+ import functools
6
+ import os
7
+
8
+ import numpy as np
9
+ import tensorflow.compat.v2 as tf
10
+
11
+ import functools
12
+ import gin
13
+ import jax
14
+ import seqio
15
+ import t5
16
+ import t5x
17
+
18
+ from mt3 import metrics_utils
19
+ from mt3 import models
20
+ from mt3 import network
21
+ from mt3 import note_sequences
22
+ from mt3 import preprocessors
23
+ from mt3 import spectrograms
24
+ from mt3 import vocabularies
25
+
26
+
27
+ import nest_asyncio
28
+ nest_asyncio.apply()
29
+
30
+ class InferenceModel(object):
31
+ """Wrapper of T5X model for music transcription."""
32
+
33
+ def __init__(self, checkpoint_path, model_type='mt3'):
34
+
35
+ # Model Constants.
36
+ if model_type == 'ismir2021':
37
+ num_velocity_bins = 127
38
+ self.encoding_spec = note_sequences.NoteEncodingSpec
39
+ self.inputs_length = 512
40
+ elif model_type == 'mt3':
41
+ num_velocity_bins = 1
42
+ self.encoding_spec = note_sequences.NoteEncodingWithTiesSpec
43
+ self.inputs_length = 256
44
+ else:
45
+ raise ValueError('unknown model_type: %s' % model_type)
46
+
47
+ gin_files = ['/home/user/app/mt3/gin/model.gin',
48
+ '/home/user/app/mt3/gin/mt3.gin']
49
+
50
+ self.batch_size = 8
51
+ self.outputs_length = 1024
52
+ self.sequence_length = {'inputs': self.inputs_length,
53
+ 'targets': self.outputs_length}
54
+
55
+ self.partitioner = t5x.partitioning.PjitPartitioner(
56
+ model_parallel_submesh=(1, 1, 1, 1), num_partitions=1)
57
+
58
+ # Build Codecs and Vocabularies.
59
+ self.spectrogram_config = spectrograms.SpectrogramConfig()
60
+ self.codec = vocabularies.build_codec(
61
+ vocab_config=vocabularies.VocabularyConfig(
62
+ num_velocity_bins=num_velocity_bins))
63
+ self.vocabulary = vocabularies.vocabulary_from_codec(self.codec)
64
+ self.output_features = {
65
+ 'inputs': seqio.ContinuousFeature(dtype=tf.float32, rank=2),
66
+ 'targets': seqio.Feature(vocabulary=self.vocabulary),
67
+ }
68
+
69
+ # Create a T5X model.
70
+ self._parse_gin(gin_files)
71
+ self.model = self._load_model()
72
+
73
+ # Restore from checkpoint.
74
+ self.restore_from_checkpoint(checkpoint_path)
75
+
76
+ @property
77
+ def input_shapes(self):
78
+ return {
79
+ 'encoder_input_tokens': (self.batch_size, self.inputs_length),
80
+ 'decoder_input_tokens': (self.batch_size, self.outputs_length)
81
+ }
82
+
83
+ def _parse_gin(self, gin_files):
84
+ """Parse gin files used to train the model."""
85
+ gin_bindings = [
86
+ 'from __gin__ import dynamic_registration',
87
+ 'from mt3 import vocabularies',
88
+ 'VOCAB_CONFIG=@vocabularies.VocabularyConfig()',
89
+ 'vocabularies.VocabularyConfig.num_velocity_bins=%NUM_VELOCITY_BINS'
90
+ ]
91
+ with gin.unlock_config():
92
+ gin.parse_config_files_and_bindings(
93
+ gin_files, gin_bindings, finalize_config=False)
94
+
95
+ def _load_model(self):
96
+ """Load up a T5X `Model` after parsing training gin config."""
97
+ model_config = gin.get_configurable(network.T5Config)()
98
+ module = network.Transformer(config=model_config)
99
+ return models.ContinuousInputsEncoderDecoderModel(
100
+ module=module,
101
+ input_vocabulary=self.output_features['inputs'].vocabulary,
102
+ output_vocabulary=self.output_features['targets'].vocabulary,
103
+ optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0),
104
+ input_depth=spectrograms.input_depth(self.spectrogram_config))
105
+
106
+
107
+ def restore_from_checkpoint(self, checkpoint_path):
108
+ """Restore training state from checkpoint, resets self._predict_fn()."""
109
+ train_state_initializer = t5x.utils.TrainStateInitializer(
110
+ optimizer_def=self.model.optimizer_def,
111
+ init_fn=self.model.get_initial_variables,
112
+ input_shapes=self.input_shapes,
113
+ partitioner=self.partitioner)
114
+
115
+ restore_checkpoint_cfg = t5x.utils.RestoreCheckpointConfig(
116
+ path=checkpoint_path, mode='specific', dtype='float32')
117
+
118
+ train_state_axes = train_state_initializer.train_state_axes
119
+ self._predict_fn = self._get_predict_fn(train_state_axes)
120
+ self._train_state = train_state_initializer.from_checkpoint_or_scratch(
121
+ [restore_checkpoint_cfg], init_rng=jax.random.PRNGKey(0))
122
+
123
+ @functools.lru_cache()
124
+ def _get_predict_fn(self, train_state_axes):
125
+ """Generate a partitioned prediction function for decoding."""
126
+ def partial_predict_fn(params, batch, decode_rng):
127
+ return self.model.predict_batch_with_aux(
128
+ params, batch, decoder_params={'decode_rng': None})
129
+ return self.partitioner.partition(
130
+ partial_predict_fn,
131
+ in_axis_resources=(
132
+ train_state_axes.params,
133
+ t5x.partitioning.PartitionSpec('data',), None),
134
+ out_axis_resources=t5x.partitioning.PartitionSpec('data',)
135
+ )
136
+
137
+ def predict_tokens(self, batch, seed=0):
138
+ """Predict tokens from preprocessed dataset batch."""
139
+ prediction, _ = self._predict_fn(
140
+ self._train_state.params, batch, jax.random.PRNGKey(seed))
141
+ return self.vocabulary.decode_tf(prediction).numpy()
142
+
143
+ def __call__(self, audio):
144
+ """Infer note sequence from audio samples.
145
+
146
+ Args:
147
+ audio: 1-d numpy array of audio samples (16kHz) for a single example.
148
+ Returns:
149
+ A note_sequence of the transcribed audio.
150
+ """
151
+ ds = self.audio_to_dataset(audio)
152
+ ds = self.preprocess(ds)
153
+
154
+ model_ds = self.model.FEATURE_CONVERTER_CLS(pack=False)(
155
+ ds, task_feature_lengths=self.sequence_length)
156
+ model_ds = model_ds.batch(self.batch_size)
157
+
158
+ inferences = (tokens for batch in model_ds.as_numpy_iterator()
159
+ for tokens in self.predict_tokens(batch))
160
+
161
+ predictions = []
162
+ for example, tokens in zip(ds.as_numpy_iterator(), inferences):
163
+ predictions.append(self.postprocess(tokens, example))
164
+
165
+ result = metrics_utils.event_predictions_to_ns(
166
+ predictions, codec=self.codec, encoding_spec=self.encoding_spec)
167
+ return result['est_ns']
168
+
169
+ def audio_to_dataset(self, audio):
170
+ """Create a TF Dataset of spectrograms from input audio."""
171
+ frames, frame_times = self._audio_to_frames(audio)
172
+ return tf.data.Dataset.from_tensors({
173
+ 'inputs': frames,
174
+ 'input_times': frame_times,
175
+ })
176
+
177
+ def _audio_to_frames(self, audio):
178
+ """Compute spectrogram frames from audio."""
179
+ frame_size = self.spectrogram_config.hop_width
180
+ padding = [0, frame_size - len(audio) % frame_size]
181
+ audio = np.pad(audio, padding, mode='constant')
182
+ frames = spectrograms.split_audio(audio, self.spectrogram_config)
183
+ num_frames = len(audio) // frame_size
184
+ times = np.arange(num_frames) / self.spectrogram_config.frames_per_second
185
+ return frames, times
186
+
187
+ def preprocess(self, ds):
188
+ pp_chain = [
189
+ functools.partial(
190
+ t5.data.preprocessors.split_tokens_to_inputs_length,
191
+ sequence_length=self.sequence_length,
192
+ output_features=self.output_features,
193
+ feature_key='inputs',
194
+ additional_feature_keys=['input_times']),
195
+ # Cache occurs here during training.
196
+ preprocessors.add_dummy_targets,
197
+ functools.partial(
198
+ preprocessors.compute_spectrograms,
199
+ spectrogram_config=self.spectrogram_config)
200
+ ]
201
+ for pp in pp_chain:
202
+ ds = pp(ds)
203
+ return ds
204
+
205
+ def postprocess(self, tokens, example):
206
+ tokens = self._trim_eos(tokens)
207
+ start_time = example['input_times'][0]
208
+ # Round down to nearest symbolic token step.
209
+ start_time -= start_time % (1 / self.codec.steps_per_second)
210
+ return {
211
+ 'est_tokens': tokens,
212
+ 'start_time': start_time,
213
+ # Internal MT3 code expects raw inputs, not used here.
214
+ 'raw_inputs': []
215
+ }
216
+
217
+ @staticmethod
218
+ def _trim_eos(tokens):
219
+ tokens = np.array(tokens, np.int32)
220
+ if vocabularies.DECODED_EOS_ID in tokens:
221
+ tokens = tokens[:np.argmax(tokens == vocabularies.DECODED_EOS_ID)]
222
+ return tokens
requirements.txt CHANGED
@@ -7,5 +7,4 @@ jax[cpu]==0.3.15 -f https://storage.googleapis.com/jax-releases/jax_releases.htm
7
  # pin CLU for python 3.7 compatibility
8
  clu==0.0.7
9
  # pin Orbax to use Checkpointer
10
- orbax==0.0.2
11
- pydub
 
7
  # pin CLU for python 3.7 compatibility
8
  clu==0.0.7
9
  # pin Orbax to use Checkpointer
10
+ orbax==0.0.2