Silentlin committed on
Commit
493b8bb
1 Parent(s): b247641

fix fs config

Browse files
egs/egs_bases/tts/fs.yaml CHANGED
@@ -36,10 +36,10 @@ dec_post_net_kernel: 3
36
 
37
  # duration
38
  predictor_hidden: -1
39
- predictor_kernel: 5
40
- predictor_layers: 2
41
  dur_predictor_kernel: 3
42
  dur_predictor_layers: 2
 
 
43
  predictor_dropout: 0.5
44
 
45
  # pitch and energy
 
36
 
37
  # duration
38
  predictor_hidden: -1
 
 
39
  dur_predictor_kernel: 3
40
  dur_predictor_layers: 2
41
+ predictor_kernel: 5
42
+ predictor_layers: 5
43
  predictor_dropout: 0.5
44
 
45
  # pitch and energy
egs/egs_bases/tts/fs2_orig.yaml CHANGED
@@ -2,10 +2,12 @@ base_config: ./fs.yaml
2
  task_cls: tasks.tts.fs2_orig.FastSpeech2OrigTask
3
  encoder_type: fft
4
  decoder_type: fft
5
- use_energy_embed: true
6
  use_pitch_embed: true
7
  pitch_type: cwt # frame|ph|cwt
8
  binarization_args:
9
  with_f0cwt: true
10
  use_gt_energy: false
11
- cwt_std_scale: 1.0
 
 
 
2
  task_cls: tasks.tts.fs2_orig.FastSpeech2OrigTask
3
  encoder_type: fft
4
  decoder_type: fft
5
+ use_energy_embed: false
6
  use_pitch_embed: true
7
  pitch_type: cwt # frame|ph|cwt
8
  binarization_args:
9
  with_f0cwt: true
10
  use_gt_energy: false
11
+ cwt_std_scale: 0.8
12
+ dropout: 0.1
13
+ mel_losses: l1
modules/tts/fs2_orig.py CHANGED
@@ -17,12 +17,12 @@ class FastSpeech2Orig(FastSpeech):
17
  self.energy_embed = Embedding(300, self.hidden_size, 0)
18
  self.energy_predictor = EnergyPredictor(
19
  self.hidden_size, n_chans=predictor_hidden,
20
- n_layers=5, dropout_rate=0.1, odim=2,
21
  kernel_size=hparams['predictor_kernel'])
22
  if hparams['pitch_type'] == 'cwt' and hparams['use_pitch_embed']:
23
  self.pitch_predictor = PitchPredictor(
24
  self.hidden_size, n_chans=predictor_hidden,
25
- n_layers=5, dropout_rate=0.1, odim=11,
26
  kernel_size=hparams['predictor_kernel'])
27
  self.cwt_stats_layers = nn.Sequential(
28
  nn.Linear(self.hidden_size, self.hidden_size), nn.ReLU(),
@@ -67,7 +67,7 @@ class FastSpeech2Orig(FastSpeech):
67
  decoder_inp = decoder_inp.detach() + self.hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
68
  pitch_padding = mel2ph == 0
69
  ret['cwt'] = cwt_out = self.pitch_predictor(decoder_inp)
70
- stats_out = self.cwt_stats_layers(encoder_out[:, 0, :]) # [B, 2]
71
  mean = ret['f0_mean'] = stats_out[:, 0]
72
  std = ret['f0_std'] = stats_out[:, 1]
73
  cwt_spec = cwt_out[:, :, :10]
 
17
  self.energy_embed = Embedding(300, self.hidden_size, 0)
18
  self.energy_predictor = EnergyPredictor(
19
  self.hidden_size, n_chans=predictor_hidden,
20
+ n_layers=hparams['predictor_layers'], dropout_rate=hparams['predictor_dropout'], odim=2,
21
  kernel_size=hparams['predictor_kernel'])
22
  if hparams['pitch_type'] == 'cwt' and hparams['use_pitch_embed']:
23
  self.pitch_predictor = PitchPredictor(
24
  self.hidden_size, n_chans=predictor_hidden,
25
+ n_layers=hparams['predictor_layers'], dropout_rate=hparams['predictor_dropout'], odim=11,
26
  kernel_size=hparams['predictor_kernel'])
27
  self.cwt_stats_layers = nn.Sequential(
28
  nn.Linear(self.hidden_size, self.hidden_size), nn.ReLU(),
 
67
  decoder_inp = decoder_inp.detach() + self.hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
68
  pitch_padding = mel2ph == 0
69
  ret['cwt'] = cwt_out = self.pitch_predictor(decoder_inp)
70
+ stats_out = self.cwt_stats_layers(decoder_inp.mean(1)) # [B, 2]
71
  mean = ret['f0_mean'] = stats_out[:, 0]
72
  std = ret['f0_std'] = stats_out[:, 1]
73
  cwt_spec = cwt_out[:, :, :10]
tasks/tts/fs2_orig.py CHANGED
@@ -5,6 +5,8 @@ from tasks.tts.dataset_utils import FastSpeechDataset
5
  from tasks.tts.fs import FastSpeechTask
6
  from utils.commons.dataset_utils import collate_1d, collate_2d
7
  from utils.commons.hparams import hparams
 
 
8
 
9
 
10
  class FastSpeech2OrigDataset(FastSpeechDataset):
@@ -88,6 +90,24 @@ class FastSpeech2OrigTask(FastSpeechTask):
88
  f0=f0, uv=uv, energy=energy, infer=True)
89
  return output
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  def add_pitch_loss(self, output, sample, losses):
92
  if hparams['pitch_type'] == 'cwt':
93
  cwt_spec = sample[f'cwt_spec']
 
5
  from tasks.tts.fs import FastSpeechTask
6
  from utils.commons.dataset_utils import collate_1d, collate_2d
7
  from utils.commons.hparams import hparams
8
+ from utils.plot.plot import spec_to_figure
9
+ import numpy as np
10
 
11
 
12
  class FastSpeech2OrigDataset(FastSpeechDataset):
 
90
  f0=f0, uv=uv, energy=energy, infer=True)
91
  return output
92
 
93
+ def save_valid_result(self, sample, batch_idx, model_out):
94
+ super(FastSpeech2OrigTask, self).save_valid_result(sample, batch_idx, model_out)
95
+ self.plot_cwt(batch_idx, model_out['cwt'], sample['cwt_spec'])
96
+
97
+ def plot_cwt(self, batch_idx, cwt_out, cwt_gt=None):
98
+ if len(cwt_out.shape) == 3:
99
+ cwt_out = cwt_out[0]
100
+ if isinstance(cwt_out, torch.Tensor):
101
+ cwt_out = cwt_out.cpu().numpy()
102
+ if cwt_gt is not None:
103
+ if len(cwt_gt.shape) == 3:
104
+ cwt_gt = cwt_gt[0]
105
+ if isinstance(cwt_gt, torch.Tensor):
106
+ cwt_gt = cwt_gt.cpu().numpy()
107
+ cwt_out = np.concatenate([cwt_out, cwt_gt], -1)
108
+ name = f'cwt_val_{batch_idx}'
109
+ self.logger.add_figure(name, spec_to_figure(cwt_out), self.global_step)
110
+
111
  def add_pitch_loss(self, output, sample, losses):
112
  if hparams['pitch_type'] == 'cwt':
113
  cwt_spec = sample[f'cwt_spec']