Hmjz100 commited on
Commit
d017266
·
verified ·
1 Parent(s): 7bd3414

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -84
app.py CHANGED
@@ -5,8 +5,8 @@ import pytz
5
  from pathlib import Path
6
 
7
  def current_time():
8
- current = datetime.datetime.now(pytz.timezone('Asia/Shanghai')).strftime("%Y年-%m月-%d日 %H时:%M分:%S秒")
9
- return current
10
 
11
  print(f"[{current_time()}] 开始部署空间...")
12
 
@@ -83,58 +83,58 @@ SAMPLE_RATE = 16000
83
  SF2_PATH = 'SGM-v2.01-Sal-Guit-Bass-V1.3.sf2'
84
 
85
  def upload_audio(audio, sample_rate):
86
- return note_seq.audio_io.wav_data_to_samples_librosa(
87
- audio, sample_rate=sample_rate)
88
 
89
 
90
  print(f"[{current_time()}] 日志:开始包装模型...")
91
  class InferenceModel(object):
92
- """音乐转录的 T5X 模型包装器。"""
93
-
94
- def __init__(self, checkpoint_path, model_type='mt3'):
95
- if model_type == 'ismir2021':
96
- num_velocity_bins = 127
97
- self.encoding_spec = note_sequences.NoteEncodingSpec
98
- self.inputs_length = 512
99
- elif model_type == 'mt3':
100
- num_velocity_bins = 1
101
- self.encoding_spec = note_sequences.NoteEncodingWithTiesSpec
102
- self.inputs_length = 256
103
- else:
104
- raise ValueError('unknown model_type: %s' % model_type)
105
-
106
- gin_files = ['/home/user/app/mt3/gin/model.gin',
107
- '/home/user/app/mt3/gin/mt3.gin']
108
-
109
- self.batch_size = 8
110
- self.outputs_length = 1024
111
- self.sequence_length = {'inputs': self.inputs_length,
112
- 'targets': self.outputs_length}
113
-
114
- self.partitioner = t5x.partitioning.PjitPartitioner(
115
- model_parallel_submesh=None, num_partitions=1)
116
-
117
- print(f"[{current_time()}] 日志:构建编解码器")
118
- self.spectrogram_config = spectrograms.SpectrogramConfig()
119
- self.codec = vocabularies.build_codec(
120
- vocab_config=vocabularies.VocabularyConfig(
121
- num_velocity_bins=num_velocity_bins)
122
- )
123
- self.vocabulary = vocabularies.vocabulary_from_codec(self.codec)
124
- self.output_features = {
125
- 'inputs': seqio.ContinuousFeature(dtype=tf.float32, rank=2),
126
- 'targets': seqio.Feature(vocabulary=self.vocabulary),
127
- }
128
-
129
- print(f"[{current_time()}] 日志:创建 T5X 模型")
130
- self._parse_gin(gin_files)
131
- self.model = self._load_model()
132
-
133
- print(f"[{current_time()}] 日志:恢复模型检查点")
134
- self.restore_from_checkpoint(checkpoint_path)
135
-
136
- @property
137
- def input_shapes(self):
138
  return {
139
  'encoder_input_tokens': (self.batch_size, self.inputs_length),
140
  'decoder_input_tokens': (self.batch_size, self.outputs_length)
@@ -144,10 +144,10 @@ class InferenceModel(object):
144
  """解析用于训练模型的 gin 文件。"""
145
  print(f"[{current_time()}] 日志:解析 gin 文件")
146
  gin_bindings = [
147
- 'from __gin__ import dynamic_registration',
148
- 'from mt3 import vocabularies',
149
- 'VOCAB_CONFIG=@vocabularies.VocabularyConfig()',
150
- 'vocabularies.VocabularyConfig.num_velocity_bins=%NUM_VELOCITY_BINS'
151
  ]
152
  with gin.unlock_config():
153
  gin.parse_config_files_and_bindings(gin_files, gin_bindings, finalize_config=False)
@@ -158,11 +158,11 @@ class InferenceModel(object):
158
  model_config = gin.get_configurable(network.T5Config)()
159
  module = network.Transformer(config=model_config)
160
  return models.ContinuousInputsEncoderDecoderModel(
161
- module=module,
162
- input_vocabulary=self.output_features['inputs'].vocabulary,
163
- output_vocabulary=self.output_features['targets'].vocabulary,
164
- optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0),
165
- input_depth=spectrograms.input_depth(self.spectrogram_config))
166
 
167
 
168
  def restore_from_checkpoint(self, checkpoint_path):
@@ -175,12 +175,12 @@ class InferenceModel(object):
175
  partitioner=self.partitioner)
176
 
177
  restore_checkpoint_cfg = t5x.utils.RestoreCheckpointConfig(
178
- path=checkpoint_path, mode='specific', dtype='float32')
179
 
180
  train_state_axes = train_state_initializer.train_state_axes
181
  self._predict_fn = self._get_predict_fn(train_state_axes)
182
  self._train_state = train_state_initializer.from_checkpoint_or_scratch(
183
- [restore_checkpoint_cfg], init_rng=jax.random.PRNGKey(0))
184
 
185
  @functools.lru_cache()
186
  def _get_predict_fn(self, train_state_axes):
@@ -189,11 +189,11 @@ class InferenceModel(object):
189
  def partial_predict_fn(params, batch, decode_rng):
190
  return self.model.predict_batch_with_aux(params, batch, decoder_params={'decode_rng': None})
191
  return self.partitioner.partition(
192
- partial_predict_fn,
193
- in_axis_resources=(
194
- train_state_axes.params,
195
- t5x.partitioning.PartitionSpec('data',), None),
196
- out_axis_resources=t5x.partitioning.PartitionSpec('data',)
197
  )
198
 
199
  def predict_tokens(self, batch, seed=0):
@@ -252,16 +252,16 @@ class InferenceModel(object):
252
  def preprocess(self, ds):
253
  pp_chain = [
254
  functools.partial(
255
- t5.data.preprocessors.split_tokens_to_inputs_length,
256
- sequence_length=self.sequence_length,
257
- output_features=self.output_features,
258
- feature_key='inputs',
259
- additional_feature_keys=['input_times']),
260
  # 在训练期间进行缓存。
261
  preprocessors.add_dummy_targets,
262
  functools.partial(
263
- preprocessors.compute_spectrograms,
264
- spectrogram_config=self.spectrogram_config)
265
  ]
266
  for pp in pp_chain:
267
  ds = pp(ds)
@@ -273,10 +273,10 @@ class InferenceModel(object):
273
  # 向下取整到最接近的符号化时间步。
274
  start_time -= start_time % (1 / self.codec.steps_per_second)
275
  return {
276
- 'est_tokens': tokens,
277
- 'start_time': start_time,
278
- # 内部 MT3 代码期望原始输入,这里不使用。
279
- 'raw_inputs': []
280
  }
281
 
282
  @staticmethod
@@ -308,11 +308,11 @@ article = "<p style='text-align: center'>出错了?试试把文件转换为MP3
308
  examples=[['canon.flac'], ['download.wav']]
309
 
310
  gr.Interface(
311
- inference,
312
- gr.Audio(type="filepath", label="输入"),
313
- outputs=gr.File(label="输出"),
314
- title=title,
315
- description=description,
316
- article=article,
317
- examples=examples
318
  ).launch(server_port=7861)
 
5
  from pathlib import Path
6
 
7
  def current_time():
8
+ current = datetime.datetime.now(pytz.timezone('Asia/Shanghai')).strftime("%Y年-%m月-%d日 %H时:%M分:%S秒")
9
+ return current
10
 
11
  print(f"[{current_time()}] 开始部署空间...")
12
 
 
83
  SF2_PATH = 'SGM-v2.01-Sal-Guit-Bass-V1.3.sf2'
84
 
85
  def upload_audio(audio, sample_rate):
86
+ return note_seq.audio_io.wav_data_to_samples_librosa(
87
+ audio, sample_rate=sample_rate)
88
 
89
 
90
  print(f"[{current_time()}] 日志:开始包装模型...")
91
  class InferenceModel(object):
92
+ """音乐转录的 T5X 模型包装器。"""
93
+
94
+ def __init__(self, checkpoint_path, model_type='mt3'):
95
+ if model_type == 'ismir2021':
96
+ num_velocity_bins = 127
97
+ self.encoding_spec = note_sequences.NoteEncodingSpec
98
+ self.inputs_length = 512
99
+ elif model_type == 'mt3':
100
+ num_velocity_bins = 1
101
+ self.encoding_spec = note_sequences.NoteEncodingWithTiesSpec
102
+ self.inputs_length = 256
103
+ else:
104
+ raise ValueError('unknown model_type: %s' % model_type)
105
+
106
+ gin_files = ['/home/user/app/mt3/gin/model.gin',
107
+ '/home/user/app/mt3/gin/mt3.gin']
108
+
109
+ self.batch_size = 8
110
+ self.outputs_length = 1024
111
+ self.sequence_length = {'inputs': self.inputs_length,
112
+ 'targets': self.outputs_length}
113
+
114
+ self.partitioner = t5x.partitioning.PjitPartitioner(
115
+ model_parallel_submesh=None, num_partitions=1)
116
+
117
+ print(f"[{current_time()}] 日志:构建编解码器")
118
+ self.spectrogram_config = spectrograms.SpectrogramConfig()
119
+ self.codec = vocabularies.build_codec(
120
+ vocab_config=vocabularies.VocabularyConfig(
121
+ num_velocity_bins=num_velocity_bins)
122
+ )
123
+ self.vocabulary = vocabularies.vocabulary_from_codec(self.codec)
124
+ self.output_features = {
125
+ 'inputs': seqio.ContinuousFeature(dtype=tf.float32, rank=2),
126
+ 'targets': seqio.Feature(vocabulary=self.vocabulary),
127
+ }
128
+
129
+ print(f"[{current_time()}] 日志:创建 T5X 模型")
130
+ self._parse_gin(gin_files)
131
+ self.model = self._load_model()
132
+
133
+ print(f"[{current_time()}] 日志:恢复模型检查点")
134
+ self.restore_from_checkpoint(checkpoint_path)
135
+
136
+ @property
137
+ def input_shapes(self):
138
  return {
139
  'encoder_input_tokens': (self.batch_size, self.inputs_length),
140
  'decoder_input_tokens': (self.batch_size, self.outputs_length)
 
144
  """解析用于训练模型的 gin 文件。"""
145
  print(f"[{current_time()}] 日志:解析 gin 文件")
146
  gin_bindings = [
147
+ 'from __gin__ import dynamic_registration',
148
+ 'from mt3 import vocabularies',
149
+ 'VOCAB_CONFIG=@vocabularies.VocabularyConfig()',
150
+ 'vocabularies.VocabularyConfig.num_velocity_bins=%NUM_VELOCITY_BINS'
151
  ]
152
  with gin.unlock_config():
153
  gin.parse_config_files_and_bindings(gin_files, gin_bindings, finalize_config=False)
 
158
  model_config = gin.get_configurable(network.T5Config)()
159
  module = network.Transformer(config=model_config)
160
  return models.ContinuousInputsEncoderDecoderModel(
161
+ module=module,
162
+ input_vocabulary=self.output_features['inputs'].vocabulary,
163
+ output_vocabulary=self.output_features['targets'].vocabulary,
164
+ optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0),
165
+ input_depth=spectrograms.input_depth(self.spectrogram_config))
166
 
167
 
168
  def restore_from_checkpoint(self, checkpoint_path):
 
175
  partitioner=self.partitioner)
176
 
177
  restore_checkpoint_cfg = t5x.utils.RestoreCheckpointConfig(
178
+ path=checkpoint_path, mode='specific', dtype='float32')
179
 
180
  train_state_axes = train_state_initializer.train_state_axes
181
  self._predict_fn = self._get_predict_fn(train_state_axes)
182
  self._train_state = train_state_initializer.from_checkpoint_or_scratch(
183
+ [restore_checkpoint_cfg], init_rng=jax.random.PRNGKey(0))
184
 
185
  @functools.lru_cache()
186
  def _get_predict_fn(self, train_state_axes):
 
189
  def partial_predict_fn(params, batch, decode_rng):
190
  return self.model.predict_batch_with_aux(params, batch, decoder_params={'decode_rng': None})
191
  return self.partitioner.partition(
192
+ partial_predict_fn,
193
+ in_axis_resources=(
194
+ train_state_axes.params,
195
+ t5x.partitioning.PartitionSpec('data',), None),
196
+ out_axis_resources=t5x.partitioning.PartitionSpec('data',)
197
  )
198
 
199
  def predict_tokens(self, batch, seed=0):
 
252
  def preprocess(self, ds):
253
  pp_chain = [
254
  functools.partial(
255
+ t5.data.preprocessors.split_tokens_to_inputs_length,
256
+ sequence_length=self.sequence_length,
257
+ output_features=self.output_features,
258
+ feature_key='inputs',
259
+ additional_feature_keys=['input_times']),
260
  # 在训练期间进行缓存。
261
  preprocessors.add_dummy_targets,
262
  functools.partial(
263
+ preprocessors.compute_spectrograms,
264
+ spectrogram_config=self.spectrogram_config)
265
  ]
266
  for pp in pp_chain:
267
  ds = pp(ds)
 
273
  # 向下取整到最接近的符号化时间步。
274
  start_time -= start_time % (1 / self.codec.steps_per_second)
275
  return {
276
+ 'est_tokens': tokens,
277
+ 'start_time': start_time,
278
+ # 内部 MT3 代码期望原始输入,这里不使用。
279
+ 'raw_inputs': []
280
  }
281
 
282
  @staticmethod
 
308
  examples=[['canon.flac'], ['download.wav']]
309
 
310
  gr.Interface(
311
+ inference,
312
+ gr.Audio(type="filepath", label="输入"),
313
+ outputs=gr.File(label="输出"),
314
+ title=title,
315
+ description=description,
316
+ article=article,
317
+ examples=examples
318
  ).launch(server_port=7861)