Silentlin committed on
Commit
b247641
1 Parent(s): b7ad36b

fix ds ckpt

Browse files
checkpoints/ds_exp/config.yaml ADDED
@@ -0,0 +1,393 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ K_step: 71
2
+ accumulate_grad_batches: 1
3
+ amp: false
4
+ audio_num_mel_bins: 80
5
+ audio_sample_rate: 22050
6
+ base_config:
7
+ - egs/egs_bases/tts/ds.yaml
8
+ - ./fs2_orig.yaml
9
+ binarization_args:
10
+ min_sil_duration: 0.1
11
+ shuffle: false
12
+ test_range:
13
+ - 0
14
+ - 523
15
+ train_range:
16
+ - 871
17
+ - -1
18
+ trim_eos_bos: false
19
+ valid_range:
20
+ - 523
21
+ - 871
22
+ with_align: true
23
+ with_f0: true
24
+ with_f0cwt: true
25
+ with_linear: false
26
+ with_spk_embed: false
27
+ with_wav: false
28
+ binarizer_cls: data_gen.tts.base_binarizer.BaseBinarizer
29
+ binary_data_dir: data/binary/ljspeech_cwt
30
+ check_val_every_n_epoch: 10
31
+ clip_grad_norm: 1
32
+ clip_grad_value: 0
33
+ conv_use_pos: false
34
+ cwt_std_scale: 0.8
35
+ debug: false
36
+ dec_dilations:
37
+ - 1
38
+ - 1
39
+ - 1
40
+ - 1
41
+ dec_ffn_kernel_size: 9
42
+ dec_inp_add_noise: false
43
+ dec_kernel_size: 5
44
+ dec_layers: 4
45
+ dec_post_net_kernel: 3
46
+ decay_steps: 50000
47
+ decoder_rnn_dim: 0
48
+ decoder_type: fft
49
+ diff_decoder_type: wavenet
50
+ diff_loss_type: l1
51
+ dilation_cycle_length: 1
52
+ dropout: 0.1
53
+ ds_workers: 2
54
+ dur_predictor_kernel: 3
55
+ dur_predictor_layers: 2
56
+ enc_dec_norm: ln
57
+ enc_dilations:
58
+ - 1
59
+ - 1
60
+ - 1
61
+ - 1
62
+ enc_ffn_kernel_size: 9
63
+ enc_kernel_size: 5
64
+ enc_layers: 4
65
+ enc_post_net_kernel: 3
66
+ enc_pre_ln: true
67
+ enc_prenet: true
68
+ encoder_K: 8
69
+ encoder_type: fft
70
+ endless_ds: true
71
+ eval_max_batches: -1
72
+ f0_max: 600
73
+ f0_min: 80
74
+ ffn_act: gelu
75
+ ffn_hidden_size: 1024
76
+ fft_size: 1024
77
+ fmax: 7600
78
+ fmin: 80
79
+ frames_multiple: 1
80
+ fs2_ckpt: checkpoints/fs2orig_exp/model_ckpt_steps_100000.ckpt
81
+ gen_dir_name: ''
82
+ griffin_lim_iters: 30
83
+ hidden_size: 256
84
+ hop_size: 256
85
+ infer: false
86
+ keep_bins: 80
87
+ lambda_commit: 0.25
88
+ lambda_energy: 0.1
89
+ lambda_f0: 1.0
90
+ lambda_ph_dur: 0.1
91
+ lambda_sent_dur: 1.0
92
+ lambda_uv: 1.0
93
+ lambda_word_dur: 1.0
94
+ layers_in_block: 2
95
+ load_ckpt: ''
96
+ loud_norm: false
97
+ lr: 0.001
98
+ max_beta: 0.06
99
+ max_epochs: 1000
100
+ max_frames: 1548
101
+ max_input_tokens: 1550
102
+ max_sentences: 128
103
+ max_tokens: 30000
104
+ max_updates: 160000
105
+ max_valid_sentences: 1
106
+ max_valid_tokens: 60000
107
+ mel_losses: l1
108
+ mel_vmax: 1.5
109
+ mel_vmin: -6
110
+ min_frames: 0
111
+ num_ckpt_keep: 3
112
+ num_heads: 2
113
+ num_sanity_val_steps: 5
114
+ num_spk: 1
115
+ num_valid_plots: 10
116
+ optimizer_adam_beta1: 0.9
117
+ optimizer_adam_beta2: 0.98
118
+ out_wav_norm: false
119
+ pitch_extractor: parselmouth
120
+ pitch_key: pitch
121
+ pitch_type: cwt
122
+ predictor_dropout: 0.5
123
+ predictor_grad: 0.1
124
+ predictor_hidden: -1
125
+ predictor_kernel: 5
126
+ predictor_layers: 5
127
+ preprocess_args:
128
+ add_eos_bos: true
129
+ mfa_group_shuffle: false
130
+ mfa_offset: 0.02
131
+ nsample_per_mfa_group: 1000
132
+ reset_phone_dict: true
133
+ reset_word_dict: true
134
+ save_sil_mask: true
135
+ txt_processor: en
136
+ use_mfa: true
137
+ vad_max_silence_length: 12
138
+ wav_processors: []
139
+ with_phsep: true
140
+ preprocess_cls: egs.datasets.audio.lj.preprocess.LJPreprocess
141
+ print_nan_grads: false
142
+ processed_data_dir: data/processed/ljspeech
143
+ profile_infer: false
144
+ raw_data_dir: data/raw/LJSpeech-1.1
145
+ ref_norm_layer: bn
146
+ rename_tmux: true
147
+ residual_channels: 256
148
+ residual_layers: 20
149
+ resume_from_checkpoint: 0
150
+ save_best: false
151
+ save_codes:
152
+ - tasks
153
+ - modules
154
+ - egs
155
+ save_f0: false
156
+ save_gt: true
157
+ schedule_type: linear
158
+ scheduler: warmup
159
+ seed: 1234
160
+ sort_by_len: true
161
+ spec_max:
162
+ - -0.5982
163
+ - -0.0778
164
+ - 0.1205
165
+ - 0.2747
166
+ - 0.4657
167
+ - 0.5123
168
+ - 0.583
169
+ - 0.7093
170
+ - 0.6461
171
+ - 0.6101
172
+ - 0.7316
173
+ - 0.7715
174
+ - 0.7681
175
+ - 0.8349
176
+ - 0.7815
177
+ - 0.7591
178
+ - 0.791
179
+ - 0.7433
180
+ - 0.7352
181
+ - 0.6869
182
+ - 0.6854
183
+ - 0.6623
184
+ - 0.5353
185
+ - 0.6492
186
+ - 0.6909
187
+ - 0.6106
188
+ - 0.5761
189
+ - 0.5236
190
+ - 0.5638
191
+ - 0.4054
192
+ - 0.4545
193
+ - 0.3407
194
+ - 0.3037
195
+ - 0.338
196
+ - 0.1599
197
+ - 0.1603
198
+ - 0.2741
199
+ - 0.213
200
+ - 0.1569
201
+ - 0.1911
202
+ - 0.2324
203
+ - 0.1586
204
+ - 0.1221
205
+ - 0.0341
206
+ - -0.0558
207
+ - 0.0553
208
+ - -0.1153
209
+ - -0.0933
210
+ - -0.1171
211
+ - -0.005
212
+ - -0.1519
213
+ - -0.1629
214
+ - -0.0522
215
+ - -0.0739
216
+ - -0.2069
217
+ - -0.2405
218
+ - -0.1244
219
+ - -0.2582
220
+ - -0.1361
221
+ - -0.1575
222
+ - -0.1442
223
+ - 0.0513
224
+ - -0.1567
225
+ - -0.2
226
+ - 0.0086
227
+ - -0.0698
228
+ - 0.1385
229
+ - 0.0941
230
+ - 0.1864
231
+ - 0.1225
232
+ - 0.1389
233
+ - 0.1382
234
+ - 0.167
235
+ - 0.1007
236
+ - 0.1444
237
+ - 0.0888
238
+ - 0.1998
239
+ - 0.228
240
+ - 0.2932
241
+ - 0.3047
242
+ spec_min:
243
+ - -4.7574
244
+ - -4.6783
245
+ - -4.6431
246
+ - -4.5832
247
+ - -4.539
248
+ - -4.6771
249
+ - -4.8089
250
+ - -4.7672
251
+ - -4.5784
252
+ - -4.7755
253
+ - -4.715
254
+ - -4.8919
255
+ - -4.8271
256
+ - -4.7389
257
+ - -4.6047
258
+ - -4.7759
259
+ - -4.6799
260
+ - -4.8201
261
+ - -4.7823
262
+ - -4.8262
263
+ - -4.7857
264
+ - -4.7545
265
+ - -4.9358
266
+ - -4.9733
267
+ - -5.1134
268
+ - -5.1395
269
+ - -4.9016
270
+ - -4.8434
271
+ - -5.0189
272
+ - -4.846
273
+ - -5.0529
274
+ - -4.951
275
+ - -5.0217
276
+ - -5.0049
277
+ - -5.1831
278
+ - -5.1445
279
+ - -5.1015
280
+ - -5.0281
281
+ - -4.9887
282
+ - -4.9916
283
+ - -4.9785
284
+ - -4.9071
285
+ - -4.9488
286
+ - -5.0342
287
+ - -4.9332
288
+ - -5.065
289
+ - -4.8924
290
+ - -5.0875
291
+ - -5.0483
292
+ - -5.0848
293
+ - -5.0655
294
+ - -5.0279
295
+ - -5.0015
296
+ - -5.0792
297
+ - -5.0636
298
+ - -5.2413
299
+ - -5.1421
300
+ - -5.171
301
+ - -5.3256
302
+ - -5.0511
303
+ - -5.1186
304
+ - -5.0057
305
+ - -5.0446
306
+ - -5.1173
307
+ - -5.0325
308
+ - -5.1085
309
+ - -5.0053
310
+ - -5.0755
311
+ - -5.1176
312
+ - -5.1004
313
+ - -5.2153
314
+ - -5.2757
315
+ - -5.3025
316
+ - -5.2867
317
+ - -5.2918
318
+ - -5.3328
319
+ - -5.2731
320
+ - -5.2985
321
+ - -5.24
322
+ - -5.2211
323
+ task_cls: tasks.tts.diffspeech.DiffSpeechTask
324
+ tb_log_interval: 100
325
+ test_ids:
326
+ - 0
327
+ - 1
328
+ - 2
329
+ - 3
330
+ - 4
331
+ - 5
332
+ - 6
333
+ - 7
334
+ - 8
335
+ - 9
336
+ - 10
337
+ - 11
338
+ - 12
339
+ - 13
340
+ - 14
341
+ - 15
342
+ - 16
343
+ - 17
344
+ - 18
345
+ - 19
346
+ - 68
347
+ - 70
348
+ - 74
349
+ - 87
350
+ - 110
351
+ - 172
352
+ - 190
353
+ - 215
354
+ - 231
355
+ - 294
356
+ - 316
357
+ - 324
358
+ - 402
359
+ - 422
360
+ - 485
361
+ - 500
362
+ - 505
363
+ - 508
364
+ - 509
365
+ - 519
366
+ test_input_yaml: ''
367
+ test_num: 100
368
+ test_set_name: test
369
+ timesteps: 100
370
+ train_set_name: train
371
+ train_sets: ''
372
+ use_energy_embed: false
373
+ use_gt_dur: false
374
+ use_gt_energy: false
375
+ use_gt_f0: false
376
+ use_pitch_embed: true
377
+ use_pos_embed: true
378
+ use_spk_embed: false
379
+ use_spk_id: false
380
+ use_uv: true
381
+ use_word_input: false
382
+ val_check_interval: 2000
383
+ valid_infer_interval: 10000
384
+ valid_monitor_key: val_loss
385
+ valid_monitor_mode: min
386
+ valid_set_name: valid
387
+ vocoder: HifiGAN
388
+ vocoder_ckpt: checkpoints/hifi_lj
389
+ warmup_updates: 4000
390
+ weight_decay: 0
391
+ win_size: 1024
392
+ word_dict_size: 10000
393
+ work_dir: checkpoints/0214_ds
checkpoints/ds_exp/model_ckpt_steps_160000.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b41f14a72d74c706ffcda7ae319c3ad873cdc4313a8be47ed117b0c67e7c9aac
3
+ size 161932032
egs/datasets/audio/lj/ds.yaml CHANGED
@@ -2,7 +2,7 @@ base_config:
2
  - egs/egs_bases/tts/ds.yaml
3
  - ./fs2_orig.yaml
4
 
5
- fs2_ckpt: checkpoints/fs2_exp/model_ckpt_steps_160000.ckpt
6
 
7
  # spec_min and spec_max are calculated on the training set.
8
  spec_min: [ -4.7574, -4.6783, -4.6431, -4.5832, -4.5390, -4.6771, -4.8089, -4.7672,
 
2
  - egs/egs_bases/tts/ds.yaml
3
  - ./fs2_orig.yaml
4
 
5
+ fs2_ckpt: checkpoints/aux_exp/model_ckpt_steps_100000.ckpt
6
 
7
  # spec_min and spec_max are calculated on the training set.
8
  spec_min: [ -4.7574, -4.6783, -4.6431, -4.5832, -4.5390, -4.6771, -4.8089, -4.7672,
egs/egs_bases/tts/ds.yaml CHANGED
@@ -28,5 +28,6 @@ keep_bins: 80
28
  num_valid_plots: 10
29
  use_gt_dur: false
30
  use_gt_f0: false
 
31
  #pitch_type: cwt
32
  max_updates: 160000
 
28
  num_valid_plots: 10
29
  use_gt_dur: false
30
  use_gt_f0: false
31
+ use_energy_embed: false
32
  #pitch_type: cwt
33
  max_updates: 160000
modules/tts/diffspeech/shallow_diffusion_tts.py CHANGED
@@ -236,8 +236,6 @@ class GaussianDiffusion(nn.Module):
236
  b, *_, device = *txt_tokens.shape, txt_tokens.device
237
  ret = self.fs2(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed, spk_id=spk_id,
238
  f0=f0, uv=uv, energy=energy, infer=infer, skip_decoder=(not infer), **kwargs)
239
- # (txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
240
- # skip_decoder=(not infer), infer=infer, **kwargs)
241
  cond = ret['decoder_inp'].transpose(1, 2)
242
 
243
  if not infer:
 
236
  b, *_, device = *txt_tokens.shape, txt_tokens.device
237
  ret = self.fs2(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed, spk_id=spk_id,
238
  f0=f0, uv=uv, energy=energy, infer=infer, skip_decoder=(not infer), **kwargs)
 
 
239
  cond = ret['decoder_inp'].transpose(1, 2)
240
 
241
  if not infer:
tasks/tts/diffspeech.py CHANGED
@@ -26,12 +26,12 @@ class DiffSpeechTask(FastSpeech2OrigTask):
26
  self.model = GaussianDiffusion(dict_size, hparams)
27
  if hparams['fs2_ckpt'] != '':
28
  load_ckpt(self.model.fs2, hparams['fs2_ckpt'], 'model', strict=True)
29
- for k, v in self.model.fs2.named_parameters():
30
- if 'predictor' not in k:
31
- v.requires_grad = False
32
- # or
33
  # for k, v in self.model.fs2.named_parameters():
34
- # v.requires_grad = False
 
 
 
 
35
 
36
  def build_optimizer(self, model):
37
  self.optimizer = optimizer = torch.optim.AdamW(
 
26
  self.model = GaussianDiffusion(dict_size, hparams)
27
  if hparams['fs2_ckpt'] != '':
28
  load_ckpt(self.model.fs2, hparams['fs2_ckpt'], 'model', strict=True)
 
 
 
 
29
  # for k, v in self.model.fs2.named_parameters():
30
+ # if 'predictor' not in k:
31
+ # v.requires_grad = False
32
+ # or
33
+ for k, v in self.model.fs2.named_parameters():
34
+ v.requires_grad = False
35
 
36
  def build_optimizer(self, model):
37
  self.optimizer = optimizer = torch.optim.AdamW(