Hmjz100 committed on
Commit
bc4a291
·
1 Parent(s): 8d308dd

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +246 -226
app.py CHANGED
@@ -1,37 +1,50 @@
 
1
  import os
2
- os.system("pip install gradio")
 
3
 
4
- import gradio as gr
5
- from pathlib import Path
6
- os.system("pip install gsutil")
7
 
 
8
 
 
 
 
 
9
  os.system("git clone --branch=main https://github.com/google-research/t5x")
 
10
  os.system("mv t5x t5x_tmp; mv t5x_tmp/* .; rm -r t5x_tmp")
 
11
  os.system("sed -i 's:jax\[tpu\]:jax:' setup.py")
 
12
  os.system("python3 -m pip install -e .")
 
13
  os.system("python3 -m pip install --upgrade pip")
14
 
15
 
16
  # 安装 mt3
 
17
  os.system("git clone --branch=main https://github.com/magenta/mt3")
 
18
  os.system("mv mt3 mt3_tmp; mv mt3_tmp/* .; rm -r mt3_tmp")
 
19
  os.system("python3 -m pip install -e .")
 
20
  os.system("pip install tensorflow_cpu")
 
21
  # 复制检查点
 
22
  os.system("gsutil -q -m cp -r gs://mt3/checkpoints .")
23
 
24
  # 复制 soundfont 文件(原始文件来自 https://sites.google.com/site/soundfonts4u)
 
25
  os.system("gsutil -q -m cp gs://magentadata/soundfonts/SGM-v2.01-Sal-Guit-Bass-V1.3.sf2 .")
26
 
27
  #@title 导入和定义
28
-
29
-
30
  import functools
31
- import os
32
 
33
  import numpy as np
34
-
35
  import tensorflow.compat.v2 as tf
36
 
37
  import functools
@@ -40,7 +53,6 @@ import jax
40
  import librosa
41
  import note_seq
42
 
43
-
44
  import seqio
45
  import t5
46
  import t5x
@@ -61,224 +73,232 @@ SAMPLE_RATE = 16000
61
  SF2_PATH = 'SGM-v2.01-Sal-Guit-Bass-V1.3.sf2'
62
 
63
  def upload_audio(audio, sample_rate):
64
- return note_seq.audio_io.wav_data_to_samples_librosa(
65
- audio, sample_rate=sample_rate)
66
-
67
 
68
 
 
69
  class InferenceModel(object):
70
- """音乐转录的 T5X 模型包装器。"""
71
-
72
- def __init__(self, checkpoint_path, model_type='mt3'):
73
-
74
- # 模型常量。
75
- if model_type == 'ismir2021':
76
- num_velocity_bins = 127
77
- self.encoding_spec = note_sequences.NoteEncodingSpec
78
- self.inputs_length = 512
79
- elif model_type == 'mt3':
80
- num_velocity_bins = 1
81
- self.encoding_spec = note_sequences.NoteEncodingWithTiesSpec
82
- self.inputs_length = 256
83
- else:
84
- raise ValueError('unknown model_type: %s' % model_type)
85
-
86
- gin_files = ['/home/user/app/mt3/gin/model.gin',
87
- '/home/user/app/mt3/gin/mt3.gin']
88
-
89
- self.batch_size = 8
90
- self.outputs_length = 1024
91
- self.sequence_length = {'inputs': self.inputs_length,
92
- 'targets': self.outputs_length}
93
-
94
- self.partitioner = t5x.partitioning.PjitPartitioner(
95
- model_parallel_submesh=None, num_partitions=1)
96
-
97
- # 构建编解码器和词汇表。
98
- self.spectrogram_config = spectrograms.SpectrogramConfig()
99
- self.codec = vocabularies.build_codec(
100
- vocab_config=vocabularies.VocabularyConfig(
101
- num_velocity_bins=num_velocity_bins))
102
- self.vocabulary = vocabularies.vocabulary_from_codec(self.codec)
103
- self.output_features = {
104
- 'inputs': seqio.ContinuousFeature(dtype=tf.float32, rank=2),
105
- 'targets': seqio.Feature(vocabulary=self.vocabulary),
106
- }
107
-
108
- # 创建 T5X 模型。
109
- self._parse_gin(gin_files)
110
- self.model = self._load_model()
111
-
112
- # 从检查点中恢复。
113
- self.restore_from_checkpoint(checkpoint_path)
114
-
115
- @property
116
- def input_shapes(self):
117
- return {
118
- 'encoder_input_tokens': (self.batch_size, self.inputs_length),
119
- 'decoder_input_tokens': (self.batch_size, self.outputs_length)
120
- }
121
-
122
- def _parse_gin(self, gin_files):
123
- """解析用于训练模型的 gin 文件。"""
124
- gin_bindings = [
125
- 'from __gin__ import dynamic_registration',
126
- 'from mt3 import vocabularies',
127
- 'VOCAB_CONFIG=@vocabularies.VocabularyConfig()',
128
- 'vocabularies.VocabularyConfig.num_velocity_bins=%NUM_VELOCITY_BINS'
129
- ]
130
- with gin.unlock_config():
131
- gin.parse_config_files_and_bindings(
132
- gin_files, gin_bindings, finalize_config=False)
133
-
134
- def _load_model(self):
135
- """在解析训练 gin 配置后加载 T5X `Model`。"""
136
- model_config = gin.get_configurable(network.T5Config)()
137
- module = network.Transformer(config=model_config)
138
- return models.ContinuousInputsEncoderDecoderModel(
139
- module=module,
140
- input_vocabulary=self.output_features['inputs'].vocabulary,
141
- output_vocabulary=self.output_features['targets'].vocabulary,
142
- optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0),
143
- input_depth=spectrograms.input_depth(self.spectrogram_config))
144
-
145
-
146
- def restore_from_checkpoint(self, checkpoint_path):
147
- """从检查点中恢复训练状态,重置 self._predict_fn()。"""
148
- train_state_initializer = t5x.utils.TrainStateInitializer(
149
- optimizer_def=self.model.optimizer_def,
150
- init_fn=self.model.get_initial_variables,
151
- input_shapes=self.input_shapes,
152
- partitioner=self.partitioner)
153
-
154
- restore_checkpoint_cfg = t5x.utils.RestoreCheckpointConfig(
155
- path=checkpoint_path, mode='specific', dtype='float32')
156
-
157
- train_state_axes = train_state_initializer.train_state_axes
158
- self._predict_fn = self._get_predict_fn(train_state_axes)
159
- self._train_state = train_state_initializer.from_checkpoint_or_scratch(
160
- [restore_checkpoint_cfg], init_rng=jax.random.PRNGKey(0))
161
-
162
- @functools.lru_cache()
163
- def _get_predict_fn(self, train_state_axes):
164
- """生成一个分区的预测函数用于解码。"""
165
- def partial_predict_fn(params, batch, decode_rng):
166
- return self.model.predict_batch_with_aux(
167
- params, batch, decoder_params={'decode_rng': None})
168
- return self.partitioner.partition(
169
- partial_predict_fn,
170
- in_axis_resources=(
171
- train_state_axes.params,
172
- t5x.partitioning.PartitionSpec('data',), None),
173
- out_axis_resources=t5x.partitioning.PartitionSpec('data',)
174
- )
175
-
176
- def predict_tokens(self, batch, seed=0):
177
- """从预处理的数据集批次中预测 tokens。"""
178
- prediction, _ = self._predict_fn(
 
 
 
 
 
 
 
 
 
179
  self._train_state.params, batch, jax.random.PRNGKey(seed))
180
- return self.vocabulary.decode_tf(prediction).numpy()
181
-
182
- def __call__(self, audio):
183
- """从音频样本推断出音符序列。
184
-
185
- 参数:
186
- audio:16kHz 的单个音频样本的 1 维 numpy 数组。
187
- 返回:
188
- 转录音频的音符序列。
189
- """
190
- ds = self.audio_to_dataset(audio)
191
- ds = self.preprocess(ds)
192
-
193
- model_ds = self.model.FEATURE_CONVERTER_CLS(pack=False)(
194
- ds, task_feature_lengths=self.sequence_length)
195
- model_ds = model_ds.batch(self.batch_size)
196
-
197
- inferences = (tokens for batch in model_ds.as_numpy_iterator()
198
- for tokens in self.predict_tokens(batch))
199
-
200
- predictions = []
201
- for example, tokens in zip(ds.as_numpy_iterator(), inferences):
202
- predictions.append(self.postprocess(tokens, example))
203
-
204
- result = metrics_utils.event_predictions_to_ns(
205
- predictions, codec=self.codec, encoding_spec=self.encoding_spec)
206
- return result['est_ns']
207
-
208
- def audio_to_dataset(self, audio):
209
- """从输入音频创建一个包含频谱图的 TF Dataset。"""
210
- frames, frame_times = self._audio_to_frames(audio)
211
- return tf.data.Dataset.from_tensors({
212
- 'inputs': frames,
213
- 'input_times': frame_times,
214
- })
215
-
216
- def _audio_to_frames(self, audio):
217
- """从音频计算频谱图帧。"""
218
- frame_size = self.spectrogram_config.hop_width
219
- padding = [0, frame_size - len(audio) % frame_size]
220
- audio = np.pad(audio, padding, mode='constant')
221
- frames = spectrograms.split_audio(audio, self.spectrogram_config)
222
- num_frames = len(audio) // frame_size
223
- times = np.arange(num_frames) / self.spectrogram_config.frames_per_second
224
- return frames, times
225
-
226
- def preprocess(self, ds):
227
- pp_chain = [
228
- functools.partial(
229
- t5.data.preprocessors.split_tokens_to_inputs_length,
230
- sequence_length=self.sequence_length,
231
- output_features=self.output_features,
232
- feature_key='inputs',
233
- additional_feature_keys=['input_times']),
234
- # 在训练期间进行缓存。
235
- preprocessors.add_dummy_targets,
236
- functools.partial(
237
- preprocessors.compute_spectrograms,
238
- spectrogram_config=self.spectrogram_config)
239
- ]
240
- for pp in pp_chain:
241
- ds = pp(ds)
242
- return ds
243
-
244
- def postprocess(self, tokens, example):
245
- tokens = self._trim_eos(tokens)
246
- start_time = example['input_times'][0]
247
- # 向下取整到最接近的符号化时间步。
248
- start_time -= start_time % (1 / self.codec.steps_per_second)
249
- return {
250
- 'est_tokens': tokens,
251
- 'start_time': start_time,
252
- # 内部 MT3 代码期望原始输入,这里不使用。
253
- 'raw_inputs': []
254
- }
255
-
256
- @staticmethod
257
- def _trim_eos(tokens):
258
- tokens = np.array(tokens, np.int32)
259
- if vocabularies.DECODED_EOS_ID in tokens:
260
- tokens = tokens[:np.argmax(tokens == vocabularies.DECODED_EOS_ID)]
261
- return tokens
 
 
 
262
 
263
 
 
264
 
265
 
 
 
 
 
266
 
 
267
 
268
- inference_model = InferenceModel('/home/user/app/checkpoints/mt3/', 'mt3')
269
 
 
270
 
271
- def inference(audio):
272
- with open(audio, 'rb') as fd:
273
- contents = fd.read()
274
- audio = upload_audio(contents,sample_rate=16000)
275
-
276
- est_ns = inference_model(audio)
277
-
278
- note_seq.sequence_proto_to_midi_file(est_ns, './transcribed.mid')
279
-
280
- return './transcribed.mid'
281
-
282
  title = "MT3"
283
  description = "MT3:多任务多音轨音乐转录的 Gradio 演示。要使用它,只需上传音频文件,或点击示例以加载它们。更多信息请参阅下面的链接。"
284
 
@@ -287,14 +307,14 @@ article = "<p style='text-align: center'>出错了?试试把文件转换为MP3
287
  examples=[['canon.flac'], ['download.wav']]
288
 
289
  gr.Interface(
290
- inference,
291
- gr.inputs.Audio(type="filepath", label="输入"),
292
- [gr.outputs.File(label="输出")],
293
- title=title,
294
- description=description,
295
- article=article,
296
- examples=examples,
297
- allow_flagging=False,
298
- allow_screenshot=False,
299
- enable_queue=True
300
- ).launch()
 
1
+ import gradio as gr
2
  import os
3
+ import datetime
4
+ import pytz
5
 
6
# Deployment bootstrap: installs T5X and MT3 from source, then downloads the
# model checkpoints and the SoundFont file. Runs at import time.

_LOG_TIME_FORMAT = "%Y年-%m月-%d日 %H时:%M分:%S秒"


def _timestamp():
    """Return the current Asia/Shanghai time formatted for log lines."""
    return datetime.datetime.now(pytz.timezone('Asia/Shanghai')).strftime(
        _LOG_TIME_FORMAT)


def _log_step(message):
    """Print `message` as a log line stamped with the time of this call."""
    print(f"[{_timestamp()}] 日志: - {message}")


def _run_step(message, command):
    """Log `message`, run `command` through the shell, and return its status.

    A non-zero status is reported but does not abort the script, so the
    remaining setup steps still run as they did before.
    """
    _log_step(message)
    status = os.system(command)
    if status != 0:
        print(f"[{_timestamp()}] 警告: - 命令退出状态非零 ({status}): {command}")
    return status


# Kept as a module-level name because other code in this file interpolates
# `current_time` into its own log lines. It holds the start-up time only.
current_time = _timestamp()

_log_step("部署空间")

from pathlib import Path

_run_step("安装 gsutil", "pip install gsutil")
_run_step("从 Github 克隆 T5X 训练框架",
          "git clone --branch=main https://github.com/google-research/t5x")
_run_step("将 T5X 训练框架转变为临时文件",
          "mv t5x t5x_tmp; mv t5x_tmp/* .; rm -r t5x_tmp")
# Raw string: the backslashes belong to the sed pattern, not to Python.
_run_step("修改当前目录下的 setup.py 内的 jax[tpu] 为 jax",
          r"sed -i 's:jax\[tpu\]:jax:' setup.py")
_run_step("安装当前目录中的 Python 包", "python3 -m pip install -e .")
_run_step("更新 Python 包管理器 pip 到最新版",
          "python3 -m pip install --upgrade pip")

# Install mt3
_run_step("从 Github 克隆 MT3 模型",
          "git clone --branch=main https://github.com/magenta/mt3")
_run_step("将 MT3 模型转变为临时文件",
          "mv mt3 mt3_tmp; mv mt3_tmp/* .; rm -r mt3_tmp")
_run_step("安装当前目录中的 Python 包", "python3 -m pip install -e .")
_run_step("安装 TensorFlow CPU版", "pip install tensorflow_cpu")

# Copy checkpoints
_run_step("复制 MT3 内的检查点到当前目录",
          "gsutil -q -m cp -r gs://mt3/checkpoints .")

# Copy the soundfont file (originally from
# https://sites.google.com/site/soundfonts4u)
_run_step("复制 SoundFont 文件到当前目录",
          "gsutil -q -m cp gs://magentadata/soundfonts/SGM-v2.01-Sal-Guit-Bass-V1.3.sf2 .")
42
 
43
  #@title 导入和定义
44
+ print(f"[{current_time}] 日志: - 导入实用命令")
 
45
  import functools
 
46
 
47
  import numpy as np
 
48
  import tensorflow.compat.v2 as tf
49
 
50
  import functools
 
53
  import librosa
54
  import note_seq
55
 
 
56
  import seqio
57
  import t5
58
  import t5x
 
73
  SF2_PATH = 'SGM-v2.01-Sal-Guit-Bass-V1.3.sf2'
74
 
75
  def upload_audio(audio, sample_rate):
76
+ return note_seq.audio_io.wav_data_to_samples_librosa(
77
+ audio, sample_rate=sample_rate)
 
78
 
79
 
80
+ print(f"[{current_time}] 日志: - 包装模型")
81
  class InferenceModel(object):
82
+ """音乐转录的 T5X 模型包装器。"""
83
+
84
+ def __init__(self, checkpoint_path, model_type='mt3'):
85
+
86
+ # 模型常量。
87
+ print(f"[{current_time}] 日志: - 设置模型常量")
88
+ if model_type == 'ismir2021':
89
+ num_velocity_bins = 127
90
+ self.encoding_spec = note_sequences.NoteEncodingSpec
91
+ self.inputs_length = 512
92
+ elif model_type == 'mt3':
93
+ num_velocity_bins = 1
94
+ self.encoding_spec = note_sequences.NoteEncodingWithTiesSpec
95
+ self.inputs_length = 256
96
+ else:
97
+ raise ValueError('unknown model_type: %s' % model_type)
98
+
99
+ gin_files = ['/home/user/app/mt3/gin/model.gin',
100
+ '/home/user/app/mt3/gin/mt3.gin']
101
+
102
+ self.batch_size = 8
103
+ self.outputs_length = 1024
104
+ self.sequence_length = {'inputs': self.inputs_length,
105
+ 'targets': self.outputs_length}
106
+
107
+ self.partitioner = t5x.partitioning.PjitPartitioner(
108
+ model_parallel_submesh=None, num_partitions=1)
109
+
110
+ # 构建编解码器和词汇表。
111
+ print(f"[{current_time}] 日志: - 构建编解码器")
112
+ self.spectrogram_config = spectrograms.SpectrogramConfig()
113
+ self.codec = vocabularies.build_codec(
114
+ vocab_config=vocabularies.VocabularyConfig(
115
+ num_velocity_bins=num_velocity_bins))
116
+ self.vocabulary = vocabularies.vocabulary_from_codec(self.codec)
117
+ self.output_features = {
118
+ 'inputs': seqio.ContinuousFeature(dtype=tf.float32, rank=2),
119
+ 'targets': seqio.Feature(vocabulary=self.vocabulary),
120
+ }
121
+
122
+ # 创建 T5X 模型。
123
+ print(f"[{current_time}] 日志: - 创建 T5X 模型")
124
+ self._parse_gin(gin_files)
125
+ self.model = self._load_model()
126
+
127
+ # 从检查点中恢复。
128
+ print(f"[{current_time}] 日志: - 恢复检查点")
129
+ self.restore_from_checkpoint(checkpoint_path)
130
+
131
+ @property
132
+ def input_shapes(self):
133
+ return {
134
+ 'encoder_input_tokens': (self.batch_size, self.inputs_length),
135
+ 'decoder_input_tokens': (self.batch_size, self.outputs_length)
136
+ }
137
+
138
+ def _parse_gin(self, gin_files):
139
+ """解析用于训练模型的 gin 文件。"""
140
+ print(f"[{current_time}] 日志: - 解析 gin 文件")
141
+ gin_bindings = [
142
+ 'from __gin__ import dynamic_registration',
143
+ 'from mt3 import vocabularies',
144
+ 'VOCAB_CONFIG=@vocabularies.VocabularyConfig()',
145
+ 'vocabularies.VocabularyConfig.num_velocity_bins=%NUM_VELOCITY_BINS'
146
+ ]
147
+ with gin.unlock_config():
148
+ gin.parse_config_files_and_bindings(
149
+ gin_files, gin_bindings, finalize_config=False)
150
+
151
+ def _load_model(self):
152
+ """在解析训练 gin 配置后加载 T5X `Model`。"""
153
+ print(f"[{current_time}] 日志: - 加载 T5X 模型")
154
+ model_config = gin.get_configurable(network.T5Config)()
155
+ module = network.Transformer(config=model_config)
156
+ return models.ContinuousInputsEncoderDecoderModel(
157
+ module=module,
158
+ input_vocabulary=self.output_features['inputs'].vocabulary,
159
+ output_vocabulary=self.output_features['targets'].vocabulary,
160
+ optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0),
161
+ input_depth=spectrograms.input_depth(self.spectrogram_config))
162
+
163
+
164
+ def restore_from_checkpoint(self, checkpoint_path):
165
+ """从检查点中恢复训练状态,重置 self._predict_fn()。"""
166
+ print(f"[{current_time}] 日志: - 从检查点恢复训练状态")
167
+ train_state_initializer = t5x.utils.TrainStateInitializer(
168
+ optimizer_def=self.model.optimizer_def,
169
+ init_fn=self.model.get_initial_variables,
170
+ input_shapes=self.input_shapes,
171
+ partitioner=self.partitioner)
172
+
173
+ restore_checkpoint_cfg = t5x.utils.RestoreCheckpointConfig(
174
+ path=checkpoint_path, mode='specific', dtype='float32')
175
+
176
+ train_state_axes = train_state_initializer.train_state_axes
177
+ self._predict_fn = self._get_predict_fn(train_state_axes)
178
+ self._train_state = train_state_initializer.from_checkpoint_or_scratch(
179
+ [restore_checkpoint_cfg], init_rng=jax.random.PRNGKey(0))
180
+
181
+ @functools.lru_cache()
182
+ def _get_predict_fn(self, train_state_axes):
183
+ """生成一个分区的预测函数用于解码。"""
184
+ print(f"[{current_time}] 日志: - 生成用于解码的预测函数")
185
+ def partial_predict_fn(params, batch, decode_rng):
186
+ return self.model.predict_batch_with_aux(
187
+ params, batch, decoder_params={'decode_rng': None})
188
+ return self.partitioner.partition(
189
+ partial_predict_fn,
190
+ in_axis_resources=(
191
+ train_state_axes.params,
192
+ t5x.partitioning.PartitionSpec('data',), None),
193
+ out_axis_resources=t5x.partitioning.PartitionSpec('data',)
194
+ )
195
+
196
+ def predict_tokens(self, batch, seed=0):
197
+ """从预处理的数据集批次中预测 tokens。"""
198
+ print(f"[{current_time}] 日志: - 从数据集中预测 tokens")
199
+ prediction, _ = self._predict_fn(
200
  self._train_state.params, batch, jax.random.PRNGKey(seed))
201
+ return self.vocabulary.decode_tf(prediction).numpy()
202
+
203
+ def __call__(self, audio):
204
+ """从音频样本推断出音符序列。
205
+
206
+ 参数:
207
+ audio:16kHz 的单个音频样本的 1 维 numpy 数组。
208
+ 返回:
209
+ 转录音频的音符序列。
210
+ """
211
+ print(f"[{current_time}] 日志: - 推断音符序列")
212
+ ds = self.audio_to_dataset(audio)
213
+ ds = self.preprocess(ds)
214
+
215
+ model_ds = self.model.FEATURE_CONVERTER_CLS(pack=False)(
216
+ ds, task_feature_lengths=self.sequence_length)
217
+ model_ds = model_ds.batch(self.batch_size)
218
+
219
+ inferences = (tokens for batch in model_ds.as_numpy_iterator()
220
+ for tokens in self.predict_tokens(batch))
221
+
222
+ predictions = []
223
+ for example, tokens in zip(ds.as_numpy_iterator(), inferences):
224
+ predictions.append(self.postprocess(tokens, example))
225
+
226
+ result = metrics_utils.event_predictions_to_ns(
227
+ predictions, codec=self.codec, encoding_spec=self.encoding_spec)
228
+ return result['est_ns']
229
+
230
+ def audio_to_dataset(self, audio):
231
+ """从输入音频创建一个包含频谱图的 TF Dataset。"""
232
+ print(f"[{current_time}] 日志: - 创建 TF Dataset")
233
+ frames, frame_times = self._audio_to_frames(audio)
234
+ return tf.data.Dataset.from_tensors({
235
+ 'inputs': frames,
236
+ 'input_times': frame_times,
237
+ })
238
+
239
+ def _audio_to_frames(self, audio):
240
+ """从音频计算频谱图帧。"""
241
+ print(f"[{current_time}] 日志: - 计算频谱图帧")
242
+ frame_size = self.spectrogram_config.hop_width
243
+ padding = [0, frame_size - len(audio) % frame_size]
244
+ audio = np.pad(audio, padding, mode='constant')
245
+ frames = spectrograms.split_audio(audio, self.spectrogram_config)
246
+ num_frames = len(audio) // frame_size
247
+ times = np.arange(num_frames) / self.spectrogram_config.frames_per_second
248
+ return frames, times
249
+
250
+ def preprocess(self, ds):
251
+ pp_chain = [
252
+ functools.partial(
253
+ t5.data.preprocessors.split_tokens_to_inputs_length,
254
+ sequence_length=self.sequence_length,
255
+ output_features=self.output_features,
256
+ feature_key='inputs',
257
+ additional_feature_keys=['input_times']),
258
+ # 在训练期间进行缓存。
259
+ preprocessors.add_dummy_targets,
260
+ functools.partial(
261
+ preprocessors.compute_spectrograms,
262
+ spectrogram_config=self.spectrogram_config)
263
+ ]
264
+ for pp in pp_chain:
265
+ ds = pp(ds)
266
+ return ds
267
+
268
+ def postprocess(self, tokens, example):
269
+ tokens = self._trim_eos(tokens)
270
+ start_time = example['input_times'][0]
271
+ # 向下取整到最接近的符号化时间步。
272
+ start_time -= start_time % (1 / self.codec.steps_per_second)
273
+ return {
274
+ 'est_tokens': tokens,
275
+ 'start_time': start_time,
276
+ # 内部 MT3 代码期望原始输入,这里不使用。
277
+ 'raw_inputs': []
278
+ }
279
+
280
+ @staticmethod
281
+ def _trim_eos(tokens):
282
+ tokens = np.array(tokens, np.int32)
283
+ if vocabularies.DECODED_EOS_ID in tokens:
284
+ tokens = tokens[:np.argmax(tokens == vocabularies.DECODED_EOS_ID)]
285
+ return tokens
286
 
287
 
288
+ inference_model = InferenceModel('/home/user/app/checkpoints/mt3/', 'mt3')
289
 
290
 
291
def inference(audio):
    """Transcribe the audio file at path `audio` and return a MIDI file path.

    Args:
      audio: path of the uploaded audio file; it is opened in binary mode.

    Returns:
      Path of a newly written `transcribed.mid` file.
    """
    import tempfile

    with open(audio, 'rb') as fd:
        contents = fd.read()
    samples = upload_audio(contents, sample_rate=SAMPLE_RATE)

    est_ns = inference_model(samples)

    # Write into a fresh directory per call so overlapping requests cannot
    # overwrite each other's output; the file name itself stays the same.
    midi_path = os.path.join(tempfile.mkdtemp(), 'transcribed.mid')
    note_seq.sequence_proto_to_midi_file(est_ns, midi_path)

    return midi_path
301
 
 
 
 
 
 
 
 
 
 
 
 
302
  title = "MT3"
303
  description = "MT3:多任务多音轨音乐转录的 Gradio 演示。要使用它,只需上传音频文件,或点击示例以加载它们。更多信息请参阅下面的链接。"
304
 
 
307
  examples=[['canon.flac'], ['download.wav']]
308
 
309
  gr.Interface(
310
+ inference,
311
+ gr.inputs.Audio(type="filepath", label="输入"),
312
+ [gr.outputs.File(label="输出")],
313
+ title=title,
314
+ description=description,
315
+ article=article,
316
+ examples=examples,
317
+ allow_flagging=False,
318
+ allow_screenshot=False,
319
+ enable_queue=True
320
+ ).launch()