yuancwang commited on
Commit
b725c5a
1 Parent(s): 3e8a9fc
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +61 -0
  2. app.py +31 -0
  3. config/audioldm.json +92 -0
  4. config/autoencoderkl.json +69 -0
  5. config/base.json +220 -0
  6. config/comosvc.json +216 -0
  7. config/diffusion.json +227 -0
  8. config/fs2.json +118 -0
  9. config/ns2.json +88 -0
  10. config/transformer.json +180 -0
  11. config/tts.json +23 -0
  12. config/valle.json +53 -0
  13. config/vits.json +101 -0
  14. config/vitssvc.json +192 -0
  15. config/vocoder.json +84 -0
  16. evaluation/__init__.py +0 -0
  17. evaluation/features/__init__.py +0 -0
  18. evaluation/features/long_term_average_spectrum.py +19 -0
  19. evaluation/features/signal_to_noise_ratio.py +133 -0
  20. evaluation/features/singing_power_ratio.py +108 -0
  21. evaluation/metrics/__init__.py +0 -0
  22. evaluation/metrics/energy/__init__.py +0 -0
  23. evaluation/metrics/energy/energy_pearson_coefficients.py +91 -0
  24. evaluation/metrics/energy/energy_rmse.py +86 -0
  25. evaluation/metrics/f0/__init__.py +0 -0
  26. evaluation/metrics/f0/f0_pearson_coefficients.py +111 -0
  27. evaluation/metrics/f0/f0_periodicity_rmse.py +112 -0
  28. evaluation/metrics/f0/f0_rmse.py +110 -0
  29. evaluation/metrics/f0/v_uv_f1.py +110 -0
  30. evaluation/metrics/intelligibility/__init__.py +0 -0
  31. evaluation/metrics/intelligibility/character_error_rate.py +81 -0
  32. evaluation/metrics/intelligibility/word_error_rate.py +81 -0
  33. evaluation/metrics/similarity/__init__.py +0 -0
  34. evaluation/metrics/similarity/models/RawNetBasicBlock.py +146 -0
  35. evaluation/metrics/similarity/models/RawNetModel.py +142 -0
  36. evaluation/metrics/similarity/models/__init__.py +0 -0
  37. evaluation/metrics/similarity/speaker_similarity.py +119 -0
  38. evaluation/metrics/spectrogram/__init__.py +0 -0
  39. evaluation/metrics/spectrogram/frechet_distance.py +31 -0
  40. evaluation/metrics/spectrogram/mel_cepstral_distortion.py +21 -0
  41. evaluation/metrics/spectrogram/multi_resolution_stft_distance.py +225 -0
  42. evaluation/metrics/spectrogram/pesq.py +56 -0
  43. evaluation/metrics/spectrogram/scale_invariant_signal_to_distortion_ratio.py +45 -0
  44. evaluation/metrics/spectrogram/scale_invariant_signal_to_noise_ratio.py +45 -0
  45. evaluation/metrics/spectrogram/short_time_objective_intelligibility.py +56 -0
  46. models/tts/base/__init__.py +7 -0
  47. models/tts/base/tts_dataset.py +389 -0
  48. models/tts/base/tts_inferece.py +268 -0
  49. models/tts/base/tts_trainer.py +699 -0
  50. models/tts/fastspeech2/__init__.py +0 -0
.gitignore ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Mac OS files
2
+ .DS_Store
3
+
4
+ # IDEs
5
+ .idea
6
+ .vs
7
+ .vscode
8
+ .cache
9
+
10
+ # GitHub files
11
+ .github
12
+
13
+ # Byte-compiled / optimized / DLL / cached files
14
+ __pycache__/
15
+ *.py[cod]
16
+ *$py.class
17
+ *.pyc
18
+ .temp
19
+ *.c
20
+ *.so
21
+ *.o
22
+
23
+ # Developing mode
24
+ _*.sh
25
+ _*.json
26
+ *.lst
27
+ yard*
28
+ *.out
29
+ evaluation/evalset_selection
30
+ mfa
31
+ egs/svc/*wavmark
32
+ egs/svc/custom
33
+ egs/svc/*/dev*
34
+ egs/svc/dev_exp_config.json
35
+ bins/svc/demo*
36
+ bins/svc/preprocess_custom.py
37
+ data
38
+
39
+ # Data and ckpt
40
+ *.pkl
41
+ *.pt
42
+ *.npy
43
+ *.npz
44
+ !modules/whisper_extractor/assets/mel_filters.npz
45
+ *.tar.gz
46
+ *.ckpt
47
+ *.wav
48
+ *.flac
49
+ pretrained/wenet/*conformer_exp
50
+
51
+ # Runtime data dirs
52
+ processed_data
53
+ data
54
+ model_ckpt
55
+ logs
56
+ *.ipynb
57
+ *.lst
58
+ source_audio
59
+ result
60
+ conversion_results
61
+ get_available_gpu.py
app.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import torch
4
+
5
+
6
+
7
+ def build_codec():
8
+ ...
9
+
10
+ def build_model():
11
+ ...
12
+
13
+ def ns2_inference(
14
+ prmopt_audio_path,
15
+ text,
16
+ diffusion_steps=100,
17
+ ):
18
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
+
20
+ demo_inputs = ...
21
+ demo_outputs = ...
22
+
23
+ demo = gr.Interface(
24
+ fn=ns2_inference,
25
+ inputs=demo_inputs,
26
+ outputs=demo_outputs,
27
+ title="Amphion Zero-Shot TTS NaturalSpeech2"
28
+ )
29
+
30
+ if __name__ == "__main__":
31
+ demo.launch()
config/audioldm.json ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "base_config": "config/base.json",
3
+ "model_type": "AudioLDM",
4
+ "task_type": "tta",
5
+ "dataset": [
6
+ "AudioCaps"
7
+ ],
8
+ "preprocess": {
9
+ // feature used for model training
10
+ "use_spkid": false,
11
+ "use_uv": false,
12
+ "use_frame_pitch": false,
13
+ "use_phone_pitch": false,
14
+ "use_frame_energy": false,
15
+ "use_phone_energy": false,
16
+ "use_mel": false,
17
+ "use_audio": false,
18
+ "use_label": false,
19
+ "use_one_hot": false,
20
+ "cond_mask_prob": 0.1
21
+ },
22
+ // model
23
+ "model": {
24
+ "audioldm": {
25
+ "image_size": 32,
26
+ "in_channels": 4,
27
+ "out_channels": 4,
28
+ "model_channels": 256,
29
+ "attention_resolutions": [
30
+ 4,
31
+ 2,
32
+ 1
33
+ ],
34
+ "num_res_blocks": 2,
35
+ "channel_mult": [
36
+ 1,
37
+ 2,
38
+ 4
39
+ ],
40
+ "num_heads": 8,
41
+ "use_spatial_transformer": true,
42
+ "transformer_depth": 1,
43
+ "context_dim": 768,
44
+ "use_checkpoint": true,
45
+ "legacy": false
46
+ },
47
+ "autoencoderkl": {
48
+ "ch": 128,
49
+ "ch_mult": [
50
+ 1,
51
+ 1,
52
+ 2,
53
+ 2,
54
+ 4
55
+ ],
56
+ "num_res_blocks": 2,
57
+ "in_channels": 1,
58
+ "z_channels": 4,
59
+ "out_ch": 1,
60
+ "double_z": true
61
+ },
62
+ "noise_scheduler": {
63
+ "num_train_timesteps": 1000,
64
+ "beta_start": 0.00085,
65
+ "beta_end": 0.012,
66
+ "beta_schedule": "scaled_linear",
67
+ "clip_sample": false,
68
+ "steps_offset": 1,
69
+ "set_alpha_to_one": false,
70
+ "skip_prk_steps": true,
71
+ "prediction_type": "epsilon"
72
+ }
73
+ },
74
+ // train
75
+ "train": {
76
+ "lronPlateau": {
77
+ "factor": 0.9,
78
+ "patience": 100,
79
+ "min_lr": 4.0e-5,
80
+ "verbose": true
81
+ },
82
+ "adam": {
83
+ "lr": 5.0e-5,
84
+ "betas": [
85
+ 0.9,
86
+ 0.999
87
+ ],
88
+ "weight_decay": 1.0e-2,
89
+ "eps": 1.0e-8
90
+ }
91
+ }
92
+ }
config/autoencoderkl.json ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "base_config": "config/base.json",
3
+ "model_type": "AutoencoderKL",
4
+ "task_type": "tta",
5
+ "dataset": [
6
+ "AudioCaps"
7
+ ],
8
+ "preprocess": {
9
+ // feature used for model training
10
+ "use_spkid": false,
11
+ "use_uv": false,
12
+ "use_frame_pitch": false,
13
+ "use_phone_pitch": false,
14
+ "use_frame_energy": false,
15
+ "use_phone_energy": false,
16
+ "use_mel": false,
17
+ "use_audio": false,
18
+ "use_label": false,
19
+ "use_one_hot": false
20
+ },
21
+ // model
22
+ "model": {
23
+ "autoencoderkl": {
24
+ "ch": 128,
25
+ "ch_mult": [
26
+ 1,
27
+ 1,
28
+ 2,
29
+ 2,
30
+ 4
31
+ ],
32
+ "num_res_blocks": 2,
33
+ "in_channels": 1,
34
+ "z_channels": 4,
35
+ "out_ch": 1,
36
+ "double_z": true
37
+ },
38
+ "loss": {
39
+ "kl_weight": 1e-8,
40
+ "disc_weight": 0.5,
41
+ "disc_factor": 1.0,
42
+ "logvar_init": 0.0,
43
+ "min_adapt_d_weight": 0.0,
44
+ "max_adapt_d_weight": 10.0,
45
+ "disc_start": 50001,
46
+ "disc_in_channels": 1,
47
+ "disc_num_layers": 3,
48
+ "use_actnorm": false
49
+ }
50
+ },
51
+ // train
52
+ "train": {
53
+ "lronPlateau": {
54
+ "factor": 0.9,
55
+ "patience": 100,
56
+ "min_lr": 4.0e-5,
57
+ "verbose": true
58
+ },
59
+ "adam": {
60
+ "lr": 4.0e-4,
61
+ "betas": [
62
+ 0.9,
63
+ 0.999
64
+ ],
65
+ "weight_decay": 1.0e-2,
66
+ "eps": 1.0e-8
67
+ }
68
+ }
69
+ }
config/base.json ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "supported_model_type": [
3
+ "GANVocoder",
4
+ "Fastspeech2",
5
+ "DiffSVC",
6
+ "Transformer",
7
+ "EDM",
8
+ "CD"
9
+ ],
10
+ "task_type": "",
11
+ "dataset": [],
12
+ "use_custom_dataset": false,
13
+ "preprocess": {
14
+ "phone_extractor": "espeak", // "espeak, pypinyin, pypinyin_initials_finals, lexicon"
15
+ // trim audio silence
16
+ "data_augment": false,
17
+ "trim_silence": false,
18
+ "num_silent_frames": 8,
19
+ "trim_fft_size": 512, // fft size used in trimming
20
+ "trim_hop_size": 128, // hop size used in trimming
21
+ "trim_top_db": 30, // top db used in trimming sensitive to each dataset
22
+ // acoustic features
23
+ "extract_mel": false,
24
+ "mel_extract_mode": "",
25
+ "extract_linear_spec": false,
26
+ "extract_mcep": false,
27
+ "extract_pitch": false,
28
+ "extract_acoustic_token": false,
29
+ "pitch_remove_outlier": false,
30
+ "extract_uv": false,
31
+ "pitch_norm": false,
32
+ "extract_audio": false,
33
+ "extract_label": false,
34
+ "pitch_extractor": "parselmouth", // pyin, dio, pyworld, pyreaper, parselmouth, CWT (Continuous Wavelet Transform)
35
+ "extract_energy": false,
36
+ "energy_remove_outlier": false,
37
+ "energy_norm": false,
38
+ "energy_extract_mode": "from_mel",
39
+ "extract_duration": false,
40
+ "extract_amplitude_phase": false,
41
+ "mel_min_max_norm": false,
42
+ // lingusitic features
43
+ "extract_phone": false,
44
+ "lexicon_path": "./text/lexicon/librispeech-lexicon.txt",
45
+ // content features
46
+ "extract_whisper_feature": false,
47
+ "extract_contentvec_feature": false,
48
+ "extract_mert_feature": false,
49
+ "extract_wenet_feature": false,
50
+ // Settings for data preprocessing
51
+ "n_mel": 80,
52
+ "win_size": 480,
53
+ "hop_size": 120,
54
+ "sample_rate": 24000,
55
+ "n_fft": 1024,
56
+ "fmin": 0,
57
+ "fmax": 12000,
58
+ "min_level_db": -115,
59
+ "ref_level_db": 20,
60
+ "bits": 8,
61
+ // Directory names of processed data or extracted features
62
+ "processed_dir": "processed_data",
63
+ "trimmed_wav_dir": "trimmed_wavs", // directory name of silence trimed wav
64
+ "raw_data": "raw_data",
65
+ "phone_dir": "phones",
66
+ "wav_dir": "wavs", // directory name of processed wav (such as downsampled waveform)
67
+ "audio_dir": "audios",
68
+ "log_amplitude_dir": "log_amplitudes",
69
+ "phase_dir": "phases",
70
+ "real_dir": "reals",
71
+ "imaginary_dir": "imaginarys",
72
+ "label_dir": "labels",
73
+ "linear_dir": "linears",
74
+ "mel_dir": "mels", // directory name of extraced mel features
75
+ "mcep_dir": "mcep", // directory name of extraced mcep features
76
+ "dur_dir": "durs",
77
+ "symbols_dict": "symbols.dict",
78
+ "lab_dir": "labs", // directory name of extraced label features
79
+ "wenet_dir": "wenet", // directory name of extraced wenet features
80
+ "contentvec_dir": "contentvec", // directory name of extraced wenet features
81
+ "pitch_dir": "pitches", // directory name of extraced pitch features
82
+ "energy_dir": "energys", // directory name of extracted energy features
83
+ "phone_pitch_dir": "phone_pitches", // directory name of extraced pitch features
84
+ "phone_energy_dir": "phone_energys", // directory name of extracted energy features
85
+ "uv_dir": "uvs", // directory name of extracted unvoiced features
86
+ "duration_dir": "duration", // ground-truth duration file
87
+ "phone_seq_file": "phone_seq_file", // phoneme sequence file
88
+ "file_lst": "file.lst",
89
+ "train_file": "train.json", // training set, the json file contains detailed information about the dataset, including dataset name, utterance id, duration of the utterance
90
+ "valid_file": "valid.json", // validattion set
91
+ "spk2id": "spk2id.json", // used for multi-speaker dataset
92
+ "utt2spk": "utt2spk", // used for multi-speaker dataset
93
+ "emo2id": "emo2id.json", // used for multi-emotion dataset
94
+ "utt2emo": "utt2emo", // used for multi-emotion dataset
95
+ // Features used for model training
96
+ "use_text": false,
97
+ "use_phone": false,
98
+ "use_phn_seq": false,
99
+ "use_lab": false,
100
+ "use_linear": false,
101
+ "use_mel": false,
102
+ "use_min_max_norm_mel": false,
103
+ "use_wav": false,
104
+ "use_phone_pitch": false,
105
+ "use_log_scale_pitch": false,
106
+ "use_phone_energy": false,
107
+ "use_phone_duration": false,
108
+ "use_log_scale_energy": false,
109
+ "use_wenet": false,
110
+ "use_dur": false,
111
+ "use_spkid": false, // True: use speaker id for multi-speaker dataset
112
+ "use_emoid": false, // True: use emotion id for multi-emotion dataset
113
+ "use_frame_pitch": false,
114
+ "use_uv": false,
115
+ "use_frame_energy": false,
116
+ "use_frame_duration": false,
117
+ "use_audio": false,
118
+ "use_label": false,
119
+ "use_one_hot": false,
120
+ "use_amplitude_phase": false,
121
+ "data_augment": false,
122
+ "align_mel_duration": false
123
+ },
124
+ "train": {
125
+ "ddp": true,
126
+ "random_seed": 970227,
127
+ "batch_size": 16,
128
+ "max_steps": 1000000,
129
+ // Trackers
130
+ "tracker": [
131
+ "tensorboard"
132
+ // "wandb",
133
+ // "cometml",
134
+ // "mlflow",
135
+ ],
136
+ "max_epoch": -1,
137
+ // -1 means no limit
138
+ "save_checkpoint_stride": [
139
+ 5,
140
+ 20
141
+ ],
142
+ // unit is epoch
143
+ "keep_last": [
144
+ 3,
145
+ -1
146
+ ],
147
+ // -1 means infinite, if one number will broadcast
148
+ "run_eval": [
149
+ false,
150
+ true
151
+ ],
152
+ // if one number will broadcast
153
+ // Fix the random seed
154
+ "random_seed": 10086,
155
+ // Optimizer
156
+ "optimizer": "AdamW",
157
+ "adamw": {
158
+ "lr": 4.0e-4
159
+ // nn model lr
160
+ },
161
+ // LR Scheduler
162
+ "scheduler": "ReduceLROnPlateau",
163
+ "reducelronplateau": {
164
+ "factor": 0.8,
165
+ "patience": 10,
166
+ // unit is epoch
167
+ "min_lr": 1.0e-4
168
+ },
169
+ // Batchsampler
170
+ "sampler": {
171
+ "holistic_shuffle": true,
172
+ "drop_last": true
173
+ },
174
+ // Dataloader
175
+ "dataloader": {
176
+ "num_worker": 32,
177
+ "pin_memory": true
178
+ },
179
+ "gradient_accumulation_step": 1,
180
+ "total_training_steps": 50000,
181
+ "save_summary_steps": 500,
182
+ "save_checkpoints_steps": 10000,
183
+ "valid_interval": 10000,
184
+ "keep_checkpoint_max": 5,
185
+ "multi_speaker_training": false, // True: train multi-speaker model; False: training single-speaker model;
186
+ "max_epoch": -1,
187
+ // -1 means no limit
188
+ "save_checkpoint_stride": [
189
+ 5,
190
+ 20
191
+ ],
192
+ // unit is epoch
193
+ "keep_last": [
194
+ 3,
195
+ -1
196
+ ],
197
+ // -1 means infinite, if one number will broadcast
198
+ "run_eval": [
199
+ false,
200
+ true
201
+ ],
202
+ // Batchsampler
203
+ "sampler": {
204
+ "holistic_shuffle": true,
205
+ "drop_last": true
206
+ },
207
+ // Dataloader
208
+ "dataloader": {
209
+ "num_worker": 32,
210
+ "pin_memory": true
211
+ },
212
+ // Trackers
213
+ "tracker": [
214
+ "tensorboard"
215
+ // "wandb",
216
+ // "cometml",
217
+ // "mlflow",
218
+ ],
219
+ },
220
+ }
config/comosvc.json ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "base_config": "config/base.json",
3
+ "model_type": "DiffComoSVC",
4
+ "task_type": "svc",
5
+ "use_custom_dataset": false,
6
+ "preprocess": {
7
+ // data augmentations
8
+ "use_pitch_shift": false,
9
+ "use_formant_shift": false,
10
+ "use_time_stretch": false,
11
+ "use_equalizer": false,
12
+ // acoustic features
13
+ "extract_mel": true,
14
+ "mel_min_max_norm": true,
15
+ "extract_pitch": true,
16
+ "pitch_extractor": "parselmouth",
17
+ "extract_uv": true,
18
+ "extract_energy": true,
19
+ // content features
20
+ "extract_whisper_feature": false,
21
+ "whisper_sample_rate": 16000,
22
+ "extract_contentvec_feature": false,
23
+ "contentvec_sample_rate": 16000,
24
+ "extract_wenet_feature": false,
25
+ "wenet_sample_rate": 16000,
26
+ "extract_mert_feature": false,
27
+ "mert_sample_rate": 16000,
28
+ // Default config for whisper
29
+ "whisper_frameshift": 0.01,
30
+ "whisper_downsample_rate": 2,
31
+ // Default config for content vector
32
+ "contentvec_frameshift": 0.02,
33
+ // Default config for mert
34
+ "mert_model": "m-a-p/MERT-v1-330M",
35
+ "mert_feature_layer": -1,
36
+ "mert_hop_size": 320,
37
+ // 24k
38
+ "mert_frameshit": 0.01333,
39
+ // 10ms
40
+ "wenet_frameshift": 0.01,
41
+ // wenetspeech is 4, gigaspeech is 6
42
+ "wenet_downsample_rate": 4,
43
+ // Default config
44
+ "n_mel": 100,
45
+ "win_size": 1024,
46
+ // todo
47
+ "hop_size": 256,
48
+ "sample_rate": 24000,
49
+ "n_fft": 1024,
50
+ // todo
51
+ "fmin": 0,
52
+ "fmax": 12000,
53
+ // todo
54
+ "f0_min": 50,
55
+ // ~C2
56
+ "f0_max": 1100,
57
+ //1100, // ~C6(1100), ~G5(800)
58
+ "pitch_bin": 256,
59
+ "pitch_max": 1100.0,
60
+ "pitch_min": 50.0,
61
+ "is_label": true,
62
+ "is_mu_law": true,
63
+ "bits": 8,
64
+ "mel_min_max_stats_dir": "mel_min_max_stats",
65
+ "whisper_dir": "whisper",
66
+ "contentvec_dir": "contentvec",
67
+ "wenet_dir": "wenet",
68
+ "mert_dir": "mert",
69
+ // Extract content features using dataloader
70
+ "pin_memory": true,
71
+ "num_workers": 8,
72
+ "content_feature_batch_size": 16,
73
+ // Features used for model training
74
+ "use_mel": true,
75
+ "use_min_max_norm_mel": true,
76
+ "use_frame_pitch": true,
77
+ "use_uv": true,
78
+ "use_frame_energy": true,
79
+ "use_log_scale_pitch": false,
80
+ "use_log_scale_energy": false,
81
+ "use_spkid": true,
82
+ // Meta file
83
+ "train_file": "train.json",
84
+ "valid_file": "test.json",
85
+ "spk2id": "singers.json",
86
+ "utt2spk": "utt2singer"
87
+ },
88
+ "model": {
89
+ "teacher_model_path": "[Your Teacher Model Path].bin",
90
+ "condition_encoder": {
91
+ "merge_mode": "add",
92
+ "input_melody_dim": 1,
93
+ "use_log_f0": true,
94
+ "n_bins_melody": 256,
95
+ //# Quantization (0 for not quantization)
96
+ "output_melody_dim": 384,
97
+ "input_loudness_dim": 1,
98
+ "use_log_loudness": true,
99
+ "n_bins_loudness": 256,
100
+ "output_loudness_dim": 384,
101
+ "use_whisper": false,
102
+ "use_contentvec": false,
103
+ "use_wenet": false,
104
+ "use_mert": false,
105
+ "whisper_dim": 1024,
106
+ "contentvec_dim": 256,
107
+ "mert_dim": 256,
108
+ "wenet_dim": 512,
109
+ "content_encoder_dim": 384,
110
+ "output_singer_dim": 384,
111
+ "singer_table_size": 512,
112
+ "output_content_dim": 384,
113
+ "use_spkid": true
114
+ },
115
+ "comosvc": {
116
+ "distill": false,
117
+ // conformer encoder
118
+ "input_dim": 384,
119
+ "output_dim": 100,
120
+ "n_heads": 2,
121
+ "n_layers": 6,
122
+ "filter_channels": 512,
123
+ "dropout": 0.1,
124
+ // karras diffusion
125
+ "P_mean": -1.2,
126
+ "P_std": 1.2,
127
+ "sigma_data": 0.5,
128
+ "sigma_min": 0.002,
129
+ "sigma_max": 80,
130
+ "rho": 7,
131
+ "n_timesteps": 40,
132
+ },
133
+ "diffusion": {
134
+ // Diffusion steps encoder
135
+ "step_encoder": {
136
+ "dim_raw_embedding": 128,
137
+ "dim_hidden_layer": 512,
138
+ "activation": "SiLU",
139
+ "num_layer": 2,
140
+ "max_period": 10000
141
+ },
142
+ // Diffusion decoder
143
+ "model_type": "bidilconv",
144
+ // bidilconv, unet2d, TODO: unet1d
145
+ "bidilconv": {
146
+ "base_channel": 384,
147
+ "n_res_block": 20,
148
+ "conv_kernel_size": 3,
149
+ "dilation_cycle_length": 4,
150
+ // specially, 1 means no dilation
151
+ "conditioner_size": 100
152
+ }
153
+ },
154
+ },
155
+ "train": {
156
+ // Basic settings
157
+ "fast_steps": 0,
158
+ "batch_size": 32,
159
+ "gradient_accumulation_step": 1,
160
+ "max_epoch": -1,
161
+ // -1 means no limit
162
+ "save_checkpoint_stride": [
163
+ 10,
164
+ 100
165
+ ],
166
+ // unit is epoch
167
+ "keep_last": [
168
+ 3,
169
+ -1
170
+ ],
171
+ // -1 means infinite, if one number will broadcast
172
+ "run_eval": [
173
+ false,
174
+ true
175
+ ],
176
+ // if one number will broadcast
177
+ // Fix the random seed
178
+ "random_seed": 10086,
179
+ // Batchsampler
180
+ "sampler": {
181
+ "holistic_shuffle": true,
182
+ "drop_last": true
183
+ },
184
+ // Dataloader
185
+ "dataloader": {
186
+ "num_worker": 32,
187
+ "pin_memory": true
188
+ },
189
+ // Trackers
190
+ "tracker": [
191
+ "tensorboard"
192
+ // "wandb",
193
+ // "cometml",
194
+ // "mlflow",
195
+ ],
196
+ // Optimizer
197
+ "optimizer": "AdamW",
198
+ "adamw": {
199
+ "lr": 4.0e-4
200
+ // nn model lr
201
+ },
202
+ // LR Scheduler
203
+ "scheduler": "ReduceLROnPlateau",
204
+ "reducelronplateau": {
205
+ "factor": 0.8,
206
+ "patience": 10,
207
+ // unit is epoch
208
+ "min_lr": 1.0e-4
209
+ }
210
+ },
211
+ "inference": {
212
+ "comosvc": {
213
+ "inference_steps": 40
214
+ }
215
+ }
216
+ }
config/diffusion.json ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ // FIXME: THESE ARE LEGACY
3
+ "base_config": "config/base.json",
4
+ "model_type": "diffusion",
5
+ "task_type": "svc",
6
+ "use_custom_dataset": false,
7
+ "preprocess": {
8
+ // data augmentations
9
+ "use_pitch_shift": false,
10
+ "use_formant_shift": false,
11
+ "use_time_stretch": false,
12
+ "use_equalizer": false,
13
+ // acoustic features
14
+ "extract_mel": true,
15
+ "mel_min_max_norm": true,
16
+ "extract_pitch": true,
17
+ "pitch_extractor": "parselmouth",
18
+ "extract_uv": true,
19
+ "extract_energy": true,
20
+ // content features
21
+ "extract_whisper_feature": false,
22
+ "whisper_sample_rate": 16000,
23
+ "extract_contentvec_feature": false,
24
+ "contentvec_sample_rate": 16000,
25
+ "extract_wenet_feature": false,
26
+ "wenet_sample_rate": 16000,
27
+ "extract_mert_feature": false,
28
+ "mert_sample_rate": 16000,
29
+ // Default config for whisper
30
+ "whisper_frameshift": 0.01,
31
+ "whisper_downsample_rate": 2,
32
+ // Default config for content vector
33
+ "contentvec_frameshift": 0.02,
34
+ // Default config for mert
35
+ "mert_model": "m-a-p/MERT-v1-330M",
36
+ "mert_feature_layer": -1,
37
+ "mert_hop_size": 320,
38
+ // 24k
39
+ "mert_frameshit": 0.01333,
40
+ // 10ms
41
+ "wenet_frameshift": 0.01,
42
+ // wenetspeech is 4, gigaspeech is 6
43
+ "wenet_downsample_rate": 4,
44
+ // Default config
45
+ "n_mel": 100,
46
+ "win_size": 1024,
47
+ // todo
48
+ "hop_size": 256,
49
+ "sample_rate": 24000,
50
+ "n_fft": 1024,
51
+ // todo
52
+ "fmin": 0,
53
+ "fmax": 12000,
54
+ // todo
55
+ "f0_min": 50,
56
+ // ~C2
57
+ "f0_max": 1100,
58
+ //1100, // ~C6(1100), ~G5(800)
59
+ "pitch_bin": 256,
60
+ "pitch_max": 1100.0,
61
+ "pitch_min": 50.0,
62
+ "is_label": true,
63
+ "is_mu_law": true,
64
+ "bits": 8,
65
+ "mel_min_max_stats_dir": "mel_min_max_stats",
66
+ "whisper_dir": "whisper",
67
+ "contentvec_dir": "contentvec",
68
+ "wenet_dir": "wenet",
69
+ "mert_dir": "mert",
70
+ // Extract content features using dataloader
71
+ "pin_memory": true,
72
+ "num_workers": 8,
73
+ "content_feature_batch_size": 16,
74
+ // Features used for model training
75
+ "use_mel": true,
76
+ "use_min_max_norm_mel": true,
77
+ "use_frame_pitch": true,
78
+ "use_uv": true,
79
+ "use_frame_energy": true,
80
+ "use_log_scale_pitch": false,
81
+ "use_log_scale_energy": false,
82
+ "use_spkid": true,
83
+ // Meta file
84
+ "train_file": "train.json",
85
+ "valid_file": "test.json",
86
+ "spk2id": "singers.json",
87
+ "utt2spk": "utt2singer"
88
+ },
89
+ "model": {
90
+ "condition_encoder": {
91
+ "merge_mode": "add",
92
+ "input_melody_dim": 1,
93
+ "use_log_f0": true,
94
+ "n_bins_melody": 256,
95
+ //# Quantization (0 for not quantization)
96
+ "output_melody_dim": 384,
97
+ "input_loudness_dim": 1,
98
+ "use_log_loudness": true,
99
+ "n_bins_loudness": 256,
100
+ "output_loudness_dim": 384,
101
+ "use_whisper": false,
102
+ "use_contentvec": false,
103
+ "use_wenet": false,
104
+ "use_mert": false,
105
+ "whisper_dim": 1024,
106
+ "contentvec_dim": 256,
107
+ "mert_dim": 256,
108
+ "wenet_dim": 512,
109
+ "content_encoder_dim": 384,
110
+ "output_singer_dim": 384,
111
+ "singer_table_size": 512,
112
+ "output_content_dim": 384,
113
+ "use_spkid": true
114
+ },
115
+ // FIXME: FOLLOWING ARE NEW!!
116
+ "diffusion": {
117
+ "scheduler": "ddpm",
118
+ "scheduler_settings": {
119
+ "num_train_timesteps": 1000,
120
+ "beta_start": 1.0e-4,
121
+ "beta_end": 0.02,
122
+ "beta_schedule": "linear"
123
+ },
124
+ // Diffusion steps encoder
125
+ "step_encoder": {
126
+ "dim_raw_embedding": 128,
127
+ "dim_hidden_layer": 512,
128
+ "activation": "SiLU",
129
+ "num_layer": 2,
130
+ "max_period": 10000
131
+ },
132
+ // Diffusion decoder
133
+ "model_type": "bidilconv",
134
+ // bidilconv, unet2d, TODO: unet1d
135
+ "bidilconv": {
136
+ "base_channel": 384,
137
+ "n_res_block": 20,
138
+ "conv_kernel_size": 3,
139
+ "dilation_cycle_length": 4,
140
+ // specially, 1 means no dilation
141
+ "conditioner_size": 384
142
+ },
143
+ "unet2d": {
144
+ "in_channels": 1,
145
+ "out_channels": 1,
146
+ "down_block_types": [
147
+ "CrossAttnDownBlock2D",
148
+ "CrossAttnDownBlock2D",
149
+ "CrossAttnDownBlock2D",
150
+ "DownBlock2D"
151
+ ],
152
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
153
+ "up_block_types": [
154
+ "UpBlock2D",
155
+ "CrossAttnUpBlock2D",
156
+ "CrossAttnUpBlock2D",
157
+ "CrossAttnUpBlock2D"
158
+ ],
159
+ "only_cross_attention": false
160
+ }
161
+ }
162
+ },
163
+ // FIXME: FOLLOWING ARE NEW!!
164
+ "train": {
165
+ // Basic settings
166
+ "batch_size": 64,
167
+ "gradient_accumulation_step": 1,
168
+ "max_epoch": -1,
169
+ // -1 means no limit
170
+ "save_checkpoint_stride": [
171
+ 5,
172
+ 20
173
+ ],
174
+ // unit is epoch
175
+ "keep_last": [
176
+ 3,
177
+ -1
178
+ ],
179
+ // -1 means infinite, if one number will broadcast
180
+ "run_eval": [
181
+ false,
182
+ true
183
+ ],
184
+ // if one number will broadcast
185
+ // Fix the random seed
186
+ "random_seed": 10086,
187
+ // Batchsampler
188
+ "sampler": {
189
+ "holistic_shuffle": true,
190
+ "drop_last": true
191
+ },
192
+ // Dataloader
193
+ "dataloader": {
194
+ "num_worker": 32,
195
+ "pin_memory": true
196
+ },
197
+ // Trackers
198
+ "tracker": [
199
+ "tensorboard"
200
+ // "wandb",
201
+ // "cometml",
202
+ // "mlflow",
203
+ ],
204
+ // Optimizer
205
+ "optimizer": "AdamW",
206
+ "adamw": {
207
+ "lr": 4.0e-4
208
+ // nn model lr
209
+ },
210
+ // LR Scheduler
211
+ "scheduler": "ReduceLROnPlateau",
212
+ "reducelronplateau": {
213
+ "factor": 0.8,
214
+ "patience": 10,
215
+ // unit is epoch
216
+ "min_lr": 1.0e-4
217
+ }
218
+ },
219
+ "inference": {
220
+ "diffusion": {
221
+ "scheduler": "pndm",
222
+ "scheduler_settings": {
223
+ "num_inference_timesteps": 1000
224
+ }
225
+ }
226
+ }
227
+ }
config/fs2.json ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "base_config": "config/tts.json",
3
+ "model_type": "FastSpeech2",
4
+ "task_type": "tts",
5
+ "dataset": ["LJSpeech"],
6
+ "preprocess": {
7
+ // acoustic features
8
+ "extract_audio": true,
9
+ "extract_mel": true,
10
+ "mel_extract_mode": "taco",
11
+ "mel_min_max_norm": false,
12
+ "extract_pitch": true,
13
+ "extract_uv": false,
14
+ "pitch_extractor": "dio",
15
+ "extract_energy": true,
16
+ "energy_extract_mode": "from_tacotron_stft",
17
+ "extract_duration": true,
18
+ "use_phone": true,
19
+ "pitch_norm": true,
20
+ "energy_norm": true,
21
+ "pitch_remove_outlier": true,
22
+ "energy_remove_outlier": true,
23
+
24
+ // Default config
25
+ "n_mel": 80,
26
+ "win_size": 1024, // todo
27
+ "hop_size": 256,
28
+ "sample_rate": 22050,
29
+ "n_fft": 1024, // todo
30
+ "fmin": 0,
31
+ "fmax": 8000, // todo
32
+ "raw_data": "raw_data",
33
+ "text_cleaners": ["english_cleaners"],
34
+ "f0_min": 71, // ~C2
35
+ "f0_max": 800, //1100, // ~C6(1100), ~G5(800)
36
+ "pitch_bin": 256,
37
+ "pitch_max": 1100.0,
38
+ "pitch_min": 50.0,
39
+ "is_label": true,
40
+ "is_mu_law": true,
41
+ "bits": 8,
42
+
43
+ "mel_min_max_stats_dir": "mel_min_max_stats",
44
+ "whisper_dir": "whisper",
45
+ "content_vector_dir": "content_vector",
46
+ "wenet_dir": "wenet",
47
+ "mert_dir": "mert",
48
+ "spk2id":"spk2id.json",
49
+ "utt2spk":"utt2spk",
50
+
51
+ // Features used for model training
52
+ "use_mel": true,
53
+ "use_min_max_norm_mel": false,
54
+ "use_frame_pitch": false,
55
+ "use_frame_energy": false,
56
+ "use_phone_pitch": true,
57
+ "use_phone_energy": true,
58
+ "use_log_scale_pitch": false,
59
+ "use_log_scale_energy": false,
60
+ "use_spkid": false,
61
+ "align_mel_duration": true,
62
+ "text_cleaners": ["english_cleaners"],
63
+ "phone_extractor": "lexicon", // "espeak, pypinyin, pypinyin_initials_finals, lexicon (only for language=en-us right now)"
64
+ },
65
+ "model": {
66
+ // Settings for transformer
67
+ "transformer": {
68
+ "encoder_layer": 4,
69
+ "encoder_head": 2,
70
+ "encoder_hidden": 256,
71
+ "decoder_layer": 6,
72
+ "decoder_head": 2,
73
+ "decoder_hidden": 256,
74
+ "conv_filter_size": 1024,
75
+ "conv_kernel_size": [9, 1],
76
+ "encoder_dropout": 0.2,
77
+ "decoder_dropout": 0.2
78
+ },
79
+
80
+ // Settings for variance_predictor
81
+ "variance_predictor":{
82
+ "filter_size": 256,
83
+ "kernel_size": 3,
84
+ "dropout": 0.5
85
+ },
86
+ "variance_embedding":{
87
+ "pitch_quantization": "linear", // support 'linear' or 'log', 'log' is allowed only if the pitch values are not normalized during preprocessing
88
+ "energy_quantization": "linear", // support 'linear' or 'log', 'log' is allowed only if the energy values are not normalized during preprocessing
89
+ "n_bins": 256
90
+ },
91
+ "max_seq_len": 1000
92
+ },
93
+ "train":{
94
+ "batch_size": 16,
95
+ "sort_sample": true,
96
+ "drop_last": true,
97
+ "group_size": 4,
98
+ "grad_clip_thresh": 1.0,
99
+ "dataloader": {
100
+ "num_worker": 8,
101
+ "pin_memory": true
102
+ },
103
+ "lr_scheduler":{
104
+ "num_warmup": 4000
105
+ },
106
+ // LR Scheduler
107
+ "scheduler": "NoamLR",
108
+ // Optimizer
109
+ "optimizer": "Adam",
110
+ "adam": {
111
+ "lr": 0.0625,
112
+ "betas": [0.9, 0.98],
113
+ "eps": 0.000000001,
114
+ "weight_decay": 0.0
115
+ },
116
+ }
117
+
118
+ }
config/ns2.json ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "base_config": "config/base.json",
3
+ "model_type": "NaturalSpeech2",
4
+ "dataset": ["LibriTTS"],
5
+ "preprocess": {
6
+ "use_mel": false,
7
+ "use_code": true,
8
+ "use_spkid": true,
9
+ "use_pitch": true,
10
+ "use_duration": true,
11
+ "use_phone": true,
12
+ "use_len": true,
13
+ "use_cross_reference": true,
14
+ "train_file": "train.json",
15
+ "melspec_dir": "mel",
16
+ "code_dir": "code",
17
+ "pitch_dir": "pitch",
18
+ "duration_dir": "duration",
19
+ "clip_mode": "start"
20
+ },
21
+ "model": {
22
+ "latent_dim": 128,
23
+ "prior_encoder": {
24
+ "vocab_size": 100,
25
+ "pitch_min": 50,
26
+ "pitch_max": 1100,
27
+ "pitch_bins_num": 512,
28
+ "encoder": {
29
+ "encoder_layer": 6,
30
+ "encoder_hidden": 512,
31
+ "encoder_head": 8,
32
+ "conv_filter_size": 2048,
33
+ "conv_kernel_size": 9,
34
+ "encoder_dropout": 0.2,
35
+ "use_cln": true
36
+ },
37
+ "duration_predictor": {
38
+ "input_size": 512,
39
+ "filter_size": 512,
40
+ "kernel_size": 3,
41
+ "conv_layers": 30,
42
+ "cross_attn_per_layer": 3,
43
+ "attn_head": 8,
44
+ "drop_out": 0.5
45
+ },
46
+ "pitch_predictor": {
47
+ "input_size": 512,
48
+ "filter_size": 512,
49
+ "kernel_size": 5,
50
+ "conv_layers": 30,
51
+ "cross_attn_per_layer": 3,
52
+ "attn_head": 8,
53
+ "drop_out": 0.5
54
+ }
55
+ },
56
+ "diffusion": {
57
+ "wavenet": {
58
+ "input_size": 128,
59
+ "hidden_size": 512,
60
+ "out_size": 128,
61
+ "num_layers": 40,
62
+ "cross_attn_per_layer": 3,
63
+ "dilation_cycle": 2,
64
+ "attn_head": 8,
65
+ "drop_out": 0.2
66
+ },
67
+ "beta_min": 0.05,
68
+ "beta_max": 20,
69
+ "sigma": 1.0,
70
+ "noise_factor": 1.0,
71
+ "ode_solver": "euler"
72
+ },
73
+ "prompt_encoder": {
74
+ "encoder_layer": 6,
75
+ "encoder_hidden": 512,
76
+ "encoder_head": 8,
77
+ "conv_filter_size": 2048,
78
+ "conv_kernel_size": 9,
79
+ "encoder_dropout": 0.2,
80
+ "use_cln": false
81
+ },
82
+ "query_emb": {
83
+ "query_token_num": 32,
84
+ "hidden_size": 512,
85
+ "head_num": 8
86
+ }
87
+ }
88
+ }
config/transformer.json ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "base_config": "config/base.json",
3
+ "model_type": "Transformer",
4
+ "task_type": "svc",
5
+ "use_custom_dataset": false,
6
+ "preprocess": {
7
+ // data augmentations
8
+ "use_pitch_shift": false,
9
+ "use_formant_shift": false,
10
+ "use_time_stretch": false,
11
+ "use_equalizer": false,
12
+ // acoustic features
13
+ "extract_mel": true,
14
+ "mel_min_max_norm": true,
15
+ "extract_pitch": true,
16
+ "pitch_extractor": "parselmouth",
17
+ "extract_uv": true,
18
+ "extract_energy": true,
19
+ // content features
20
+ "extract_whisper_feature": false,
21
+ "whisper_sample_rate": 16000,
22
+ "extract_contentvec_feature": false,
23
+ "contentvec_sample_rate": 16000,
24
+ "extract_wenet_feature": false,
25
+ "wenet_sample_rate": 16000,
26
+ "extract_mert_feature": false,
27
+ "mert_sample_rate": 16000,
28
+ // Default config for whisper
29
+ "whisper_frameshift": 0.01,
30
+ "whisper_downsample_rate": 2,
31
+ // Default config for content vector
32
+ "contentvec_frameshift": 0.02,
33
+ // Default config for mert
34
+ "mert_model": "m-a-p/MERT-v1-330M",
35
+ "mert_feature_layer": -1,
36
+ "mert_hop_size": 320,
37
+ // 24k
38
+ "mert_frameshit": 0.01333,
39
+ // 10ms
40
+ "wenet_frameshift": 0.01,
41
+ // wenetspeech is 4, gigaspeech is 6
42
+ "wenet_downsample_rate": 4,
43
+ // Default config
44
+ "n_mel": 100,
45
+ "win_size": 1024,
46
+ // todo
47
+ "hop_size": 256,
48
+ "sample_rate": 24000,
49
+ "n_fft": 1024,
50
+ // todo
51
+ "fmin": 0,
52
+ "fmax": 12000,
53
+ // todo
54
+ "f0_min": 50,
55
+ // ~C2
56
+ "f0_max": 1100,
57
+ //1100, // ~C6(1100), ~G5(800)
58
+ "pitch_bin": 256,
59
+ "pitch_max": 1100.0,
60
+ "pitch_min": 50.0,
61
+ "is_label": true,
62
+ "is_mu_law": true,
63
+ "bits": 8,
64
+ "mel_min_max_stats_dir": "mel_min_max_stats",
65
+ "whisper_dir": "whisper",
66
+ "contentvec_dir": "contentvec",
67
+ "wenet_dir": "wenet",
68
+ "mert_dir": "mert",
69
+ // Extract content features using dataloader
70
+ "pin_memory": true,
71
+ "num_workers": 8,
72
+ "content_feature_batch_size": 16,
73
+ // Features used for model training
74
+ "use_mel": true,
75
+ "use_min_max_norm_mel": true,
76
+ "use_frame_pitch": true,
77
+ "use_uv": true,
78
+ "use_frame_energy": true,
79
+ "use_log_scale_pitch": false,
80
+ "use_log_scale_energy": false,
81
+ "use_spkid": true,
82
+ // Meta file
83
+ "train_file": "train.json",
84
+ "valid_file": "test.json",
85
+ "spk2id": "singers.json",
86
+ "utt2spk": "utt2singer"
87
+ },
88
+ "model": {
89
+ "condition_encoder": {
90
+ "merge_mode": "add",
91
+ "input_melody_dim": 1,
92
+ "use_log_f0": true,
93
+ "n_bins_melody": 256,
94
+ //# Quantization (0 for not quantization)
95
+ "output_melody_dim": 384,
96
+ "input_loudness_dim": 1,
97
+ "use_log_loudness": true,
98
+ "n_bins_loudness": 256,
99
+ "output_loudness_dim": 384,
100
+ "use_whisper": false,
101
+ "use_contentvec": true,
102
+ "use_wenet": false,
103
+ "use_mert": false,
104
+ "whisper_dim": 1024,
105
+ "contentvec_dim": 256,
106
+ "mert_dim": 256,
107
+ "wenet_dim": 512,
108
+ "content_encoder_dim": 384,
109
+ "output_singer_dim": 384,
110
+ "singer_table_size": 512,
111
+ "output_content_dim": 384,
112
+ "use_spkid": true
113
+ },
114
+ "transformer": {
115
+ "type": "conformer",
116
+ // 'conformer' or 'transformer'
117
+ "input_dim": 384,
118
+ "output_dim": 100,
119
+ "n_heads": 2,
120
+ "n_layers": 6,
121
+ "filter_channels": 512,
122
+ "dropout": 0.1,
123
+ }
124
+ },
125
+ "train": {
126
+ // Basic settings
127
+ "batch_size": 64,
128
+ "gradient_accumulation_step": 1,
129
+ "max_epoch": -1,
130
+ // -1 means no limit
131
+ "save_checkpoint_stride": [
132
+ 10,
133
+ 100
134
+ ],
135
+ // unit is epoch
136
+ "keep_last": [
137
+ 3,
138
+ -1
139
+ ],
140
+ // -1 means infinite, if one number will broadcast
141
+ "run_eval": [
142
+ false,
143
+ true
144
+ ],
145
+ // if one number will broadcast
146
+ // Fix the random seed
147
+ "random_seed": 10086,
148
+ // Batchsampler
149
+ "sampler": {
150
+ "holistic_shuffle": true,
151
+ "drop_last": true
152
+ },
153
+ // Dataloader
154
+ "dataloader": {
155
+ "num_worker": 32,
156
+ "pin_memory": true
157
+ },
158
+ // Trackers
159
+ "tracker": [
160
+ "tensorboard"
161
+ // "wandb",
162
+ // "cometml",
163
+ // "mlflow",
164
+ ],
165
+ // Optimizer
166
+ "optimizer": "AdamW",
167
+ "adamw": {
168
+ "lr": 4.0e-4
169
+ // nn model lr
170
+ },
171
+ // LR Scheduler
172
+ "scheduler": "ReduceLROnPlateau",
173
+ "reducelronplateau": {
174
+ "factor": 0.8,
175
+ "patience": 10,
176
+ // unit is epoch
177
+ "min_lr": 1.0e-4
178
+ }
179
+ }
180
+ }
config/tts.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "base_config": "config/base.json",
3
+ "supported_model_type": [
4
+ "Fastspeech2",
5
+ "VITS",
6
+ "VALLE",
7
+ ],
8
+ "task_type": "tts",
9
+ "preprocess": {
10
+ "language": "en-us",
11
+ // linguistic features
12
+ "extract_phone": true,
13
+ "phone_extractor": "espeak", // "espeak, pypinyin, pypinyin_initials_finals, lexicon (only for language=en-us right now)"
14
+ "lexicon_path": "./text/lexicon/librispeech-lexicon.txt",
15
+ // Directory names of processed data or extracted features
16
+ "phone_dir": "phones",
17
+ "use_phone": true,
18
+ },
19
+ "model": {
20
+ "text_token_num": 512,
21
+ }
22
+
23
+ }
config/valle.json ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "base_config": "config/tts.json",
3
+ "model_type": "VALLE",
4
+ "task_type": "tts",
5
+ "dataset": [
6
+ "libritts"
7
+ ],
8
+ "preprocess": {
9
+ "extract_phone": true,
10
+ "phone_extractor": "espeak", // phoneme extractor: espeak, pypinyin, pypinyin_initials_finals or lexicon
11
+ "extract_acoustic_token": true,
12
+ "acoustic_token_extractor": "Encodec", // acoustic token extractor: encodec, dac(todo)
13
+ "acoustic_token_dir": "acoutic_tokens",
14
+ "use_text": false,
15
+ "use_phone": true,
16
+ "use_acoustic_token": true,
17
+ "symbols_dict": "symbols.dict",
18
+ "min_duration": 0.5, // the duration lowerbound to filter the audio with duration < min_duration
19
+ "max_duration": 14, // the duration uperbound to filter the audio with duration > max_duration.
20
+ "sample_rate": 24000,
21
+ "codec_hop_size": 320
22
+ },
23
+ "model": {
24
+ "text_token_num": 512,
25
+ "audio_token_num": 1024,
26
+ "decoder_dim": 1024, // embedding dimension of the decoder model
27
+ "nhead": 16, // number of attention heads in the decoder layers
28
+ "num_decoder_layers": 12, // number of decoder layers
29
+ "norm_first": true, // pre or post Normalization.
30
+ "add_prenet": false, // whether add PreNet after Inputs
31
+ "prefix_mode": 0, // mode for how to prefix VALL-E NAR Decoder, 0: no prefix, 1: 0 to random, 2: random to random, 4: chunk of pre or post utterance
32
+ "share_embedding": true, // share the parameters of the output projection layer with the parameters of the acoustic embedding
33
+ "nar_scale_factor": 1, // model scale factor which will be assigned different meanings in different models
34
+ "prepend_bos": false, // whether prepend <BOS> to the acoustic tokens -> AR Decoder inputs
35
+ "num_quantizers": 8, // numbert of the audio quantization layers
36
+ // "scaling_xformers": false, // Apply Reworked Conformer scaling on Transformers
37
+ },
38
+ "train": {
39
+ "ddp": false,
40
+ "train_stage": 1, // 0: train all modules, For VALL_E, support 1: AR Decoder 2: NAR Decoder(s)
41
+ "max_epoch": 20,
42
+ "optimizer": "AdamW",
43
+ "scheduler": "cosine",
44
+ "warmup_steps": 16000, // number of steps that affects how rapidly the learning rate decreases
45
+ "base_lr": 1e-4, // base learning rate."
46
+ "valid_interval": 1000,
47
+ "log_epoch_step": 1000,
48
+ "save_checkpoint_stride": [
49
+ 1,
50
+ 1
51
+ ]
52
+ }
53
+ }
config/vits.json ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "base_config": "config/tts.json",
3
+ "model_type": "VITS",
4
+ "task_type": "tts",
5
+ "preprocess": {
6
+ "extract_phone": true,
7
+ "extract_mel": true,
8
+ "n_mel": 80,
9
+ "fmin": 0,
10
+ "fmax": null,
11
+ "extract_linear_spec": true,
12
+ "extract_audio": true,
13
+ "use_linear": true,
14
+ "use_mel": true,
15
+ "use_audio": true,
16
+ "use_text": false,
17
+ "use_phone": true,
18
+ "lexicon_path": "./text/lexicon/librispeech-lexicon.txt",
19
+ "n_fft": 1024,
20
+ "win_size": 1024,
21
+ "hop_size": 256,
22
+ "segment_size": 8192,
23
+ "text_cleaners": [
24
+ "english_cleaners"
25
+ ]
26
+ },
27
+ "model": {
28
+ "text_token_num": 512,
29
+ "inter_channels": 192,
30
+ "hidden_channels": 192,
31
+ "filter_channels": 768,
32
+ "n_heads": 2,
33
+ "n_layers": 6,
34
+ "kernel_size": 3,
35
+ "p_dropout": 0.1,
36
+ "resblock": "1",
37
+ "resblock_kernel_sizes": [
38
+ 3,
39
+ 7,
40
+ 11
41
+ ],
42
+ "resblock_dilation_sizes": [
43
+ [
44
+ 1,
45
+ 3,
46
+ 5
47
+ ],
48
+ [
49
+ 1,
50
+ 3,
51
+ 5
52
+ ],
53
+ [
54
+ 1,
55
+ 3,
56
+ 5
57
+ ]
58
+ ],
59
+ "upsample_rates": [
60
+ 8,
61
+ 8,
62
+ 2,
63
+ 2
64
+ ],
65
+ "upsample_initial_channel": 512,
66
+ "upsample_kernel_sizes": [
67
+ 16,
68
+ 16,
69
+ 4,
70
+ 4
71
+ ],
72
+ "n_layers_q": 3,
73
+ "use_spectral_norm": false,
74
+ "n_speakers": 0, // number of speakers, while be automatically set if n_speakers is 0 and multi_speaker_training is true
75
+ "gin_channels": 256,
76
+ "use_sdp": true
77
+ },
78
+ "train": {
79
+ "fp16_run": true,
80
+ "learning_rate": 2e-4,
81
+ "betas": [
82
+ 0.8,
83
+ 0.99
84
+ ],
85
+ "eps": 1e-9,
86
+ "batch_size": 16,
87
+ "lr_decay": 0.999875,
88
+ // "segment_size": 8192,
89
+ "init_lr_ratio": 1,
90
+ "warmup_epochs": 0,
91
+ "c_mel": 45,
92
+ "c_kl": 1.0,
93
+ "AdamW": {
94
+ "betas": [
95
+ 0.8,
96
+ 0.99
97
+ ],
98
+ "eps": 1e-9,
99
+ }
100
+ }
101
+ }
config/vitssvc.json ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "base_config": "config/base.json",
3
+ "model_type": "VITS",
4
+ "task_type": "svc",
5
+ "preprocess": {
6
+ "extract_phone": false,
7
+ "extract_mel": true,
8
+ "extract_linear_spec": true,
9
+ "extract_audio": true,
10
+ "use_linear": true,
11
+ "use_mel": true,
12
+ "use_audio": true,
13
+ "use_text": false,
14
+ "use_phone": true,
15
+
16
+ "fmin": 0,
17
+ "fmax": null,
18
+ "f0_min": 50,
19
+ "f0_max": 1100,
20
+ // f0_bin in sovits
21
+ "pitch_bin": 256,
22
+ // filter_length in sovits
23
+ "n_fft": 2048,
24
+ // hop_length in sovits
25
+ "hop_size": 512,
26
+ // win_length in sovits
27
+ "win_size": 2048,
28
+ "segment_size": 8192,
29
+ "n_mel": 100,
30
+ "sample_rate": 44100,
31
+
32
+ "mel_min_max_stats_dir": "mel_min_max_stats",
33
+ "whisper_dir": "whisper",
34
+ "contentvec_dir": "contentvec",
35
+ "wenet_dir": "wenet",
36
+ "mert_dir": "mert",
37
+ },
38
+ "model": {
39
+ "condition_encoder": {
40
+ "merge_mode": "add",
41
+ "input_melody_dim": 1,
42
+ "use_log_f0": true,
43
+ "n_bins_melody": 256,
44
+ //# Quantization (0 for not quantization)
45
+ "output_melody_dim": 196,
46
+ "input_loudness_dim": 1,
47
+ "use_log_loudness": false,
48
+ "n_bins_loudness": 256,
49
+ "output_loudness_dim": 196,
50
+ "use_whisper": false,
51
+ "use_contentvec": false,
52
+ "use_wenet": false,
53
+ "use_mert": false,
54
+ "whisper_dim": 1024,
55
+ "contentvec_dim": 256,
56
+ "mert_dim": 256,
57
+ "wenet_dim": 512,
58
+ "content_encoder_dim": 196,
59
+ "output_singer_dim": 196,
60
+ "singer_table_size": 512,
61
+ "output_content_dim": 196,
62
+ "use_spkid": true
63
+ },
64
+ "vits": {
65
+ "filter_channels": 256,
66
+ "gin_channels": 256,
67
+ "hidden_channels": 192,
68
+ "inter_channels": 192,
69
+ "kernel_size": 3,
70
+ "n_flow_layer": 4,
71
+ "n_heads": 2,
72
+ "n_layers": 6,
73
+ "n_layers_q": 3,
74
+ "n_speakers": 512,
75
+ "p_dropout": 0.1,
76
+ "ssl_dim": 256,
77
+ "use_spectral_norm": false,
78
+ },
79
+ "generator": "hifigan",
80
+ "generator_config": {
81
+ "hifigan": {
82
+ "resblock": "1",
83
+ "resblock_kernel_sizes": [
84
+ 3,
85
+ 7,
86
+ 11
87
+ ],
88
+ "upsample_rates": [
89
+ 8,8,2,2,2
90
+ ],
91
+ "upsample_kernel_sizes": [
92
+ 16,16,4,4,4
93
+ ],
94
+ "upsample_initial_channel": 512,
95
+ "resblock_dilation_sizes": [
96
+ [1,3,5],
97
+ [1,3,5],
98
+ [1,3,5]
99
+ ]
100
+ },
101
+ "melgan": {
102
+ "ratios": [8, 8, 2, 2, 2],
103
+ "ngf": 32,
104
+ "n_residual_layers": 3,
105
+ "num_D": 3,
106
+ "ndf": 16,
107
+ "n_layers": 4,
108
+ "downsampling_factor": 4
109
+ },
110
+ "bigvgan": {
111
+ "resblock": "1",
112
+ "activation": "snakebeta",
113
+ "snake_logscale": true,
114
+ "upsample_rates": [
115
+ 8,8,2,2,2,
116
+ ],
117
+ "upsample_kernel_sizes": [
118
+ 16,16,4,4,4,
119
+ ],
120
+ "upsample_initial_channel": 512,
121
+ "resblock_kernel_sizes": [
122
+ 3,
123
+ 7,
124
+ 11
125
+ ],
126
+ "resblock_dilation_sizes": [
127
+ [1,3,5],
128
+ [1,3,5],
129
+ [1,3,5]
130
+ ]
131
+ },
132
+ "nsfhifigan": {
133
+ "resblock": "1",
134
+ "harmonic_num": 8,
135
+ "upsample_rates": [
136
+ 8,8,2,2,2,
137
+ ],
138
+ "upsample_kernel_sizes": [
139
+ 16,16,4,4,4,
140
+ ],
141
+ "upsample_initial_channel": 768,
142
+ "resblock_kernel_sizes": [
143
+ 3,
144
+ 7,
145
+ 11
146
+ ],
147
+ "resblock_dilation_sizes": [
148
+ [1,3,5],
149
+ [1,3,5],
150
+ [1,3,5]
151
+ ]
152
+ },
153
+ "apnet": {
154
+ "ASP_channel": 512,
155
+ "ASP_resblock_kernel_sizes": [3,7,11],
156
+ "ASP_resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
157
+ "ASP_input_conv_kernel_size": 7,
158
+ "ASP_output_conv_kernel_size": 7,
159
+
160
+ "PSP_channel": 512,
161
+ "PSP_resblock_kernel_sizes": [3,7,11],
162
+ "PSP_resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
163
+ "PSP_input_conv_kernel_size": 7,
164
+ "PSP_output_R_conv_kernel_size": 7,
165
+ "PSP_output_I_conv_kernel_size": 7,
166
+ }
167
+ },
168
+ },
169
+ "train": {
170
+ "fp16_run": true,
171
+ "learning_rate": 2e-4,
172
+ "betas": [
173
+ 0.8,
174
+ 0.99
175
+ ],
176
+ "eps": 1e-9,
177
+ "batch_size": 16,
178
+ "lr_decay": 0.999875,
179
+ // "segment_size": 8192,
180
+ "init_lr_ratio": 1,
181
+ "warmup_epochs": 0,
182
+ "c_mel": 45,
183
+ "c_kl": 1.0,
184
+ "AdamW": {
185
+ "betas": [
186
+ 0.8,
187
+ 0.99
188
+ ],
189
+ "eps": 1e-9,
190
+ }
191
+ }
192
+ }
config/vocoder.json ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "base_config": "config/base.json",
3
+ "dataset": [
4
+ "LJSpeech",
5
+ "LibriTTS",
6
+ "opencpop",
7
+ "m4singer",
8
+ "svcc",
9
+ "svcceval",
10
+ "pjs",
11
+ "opensinger",
12
+ "popbutfy",
13
+ "nus48e",
14
+ "popcs",
15
+ "kising",
16
+ "csd",
17
+ "opera",
18
+ "vctk",
19
+ "lijian",
20
+ "cdmusiceval"
21
+ ],
22
+ "task_type": "vocoder",
23
+ "preprocess": {
24
+ // acoustic features
25
+ "extract_mel": true,
26
+ "extract_pitch": false,
27
+ "extract_uv": false,
28
+ "extract_audio": true,
29
+ "extract_label": false,
30
+ "extract_one_hot": false,
31
+ "extract_amplitude_phase": false,
32
+ "pitch_extractor": "parselmouth",
33
+ // Settings for data preprocessing
34
+ "n_mel": 100,
35
+ "win_size": 1024,
36
+ "hop_size": 256,
37
+ "sample_rate": 24000,
38
+ "n_fft": 1024,
39
+ "fmin": 0,
40
+ "fmax": 12000,
41
+ "f0_min": 50,
42
+ "f0_max": 1100,
43
+ "pitch_bin": 256,
44
+ "pitch_max": 1100.0,
45
+ "pitch_min": 50.0,
46
+ "is_mu_law": false,
47
+ "bits": 8,
48
+ "cut_mel_frame": 32,
49
+ // Directory names of processed data or extracted features
50
+ "spk2id": "singers.json",
51
+ // Features used for model training
52
+ "use_mel": true,
53
+ "use_frame_pitch": false,
54
+ "use_uv": false,
55
+ "use_audio": true,
56
+ "use_label": false,
57
+ "use_one_hot": false,
58
+ "train_file": "train.json",
59
+ "valid_file": "test.json"
60
+ },
61
+ "train": {
62
+ "random_seed": 114514,
63
+ "batch_size": 64,
64
+ "gradient_accumulation_step": 1,
65
+ "max_epoch": 1000000,
66
+ "save_checkpoint_stride": [
67
+ 20
68
+ ],
69
+ "run_eval": [
70
+ true
71
+ ],
72
+ "sampler": {
73
+ "holistic_shuffle": true,
74
+ "drop_last": true
75
+ },
76
+ "dataloader": {
77
+ "num_worker": 4,
78
+ "pin_memory": true
79
+ },
80
+ "tracker": [
81
+ "tensorboard"
82
+ ],
83
+ }
84
+ }
evaluation/__init__.py ADDED
File without changes
evaluation/features/__init__.py ADDED
File without changes
evaluation/features/long_term_average_spectrum.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import librosa
7
+ from scipy import signal
8
+
9
+
10
+ def extract_ltas(audio, fs=None, n_fft=1024, hop_length=256):
11
+ """Extract Long-Term Average Spectrum for a given audio."""
12
+ if fs != None:
13
+ y, _ = librosa.load(audio, sr=fs)
14
+ else:
15
+ y, fs = librosa.load(audio)
16
+ frequency, density = signal.welch(
17
+ x=y, fs=fs, window="hann", nperseg=hop_length, nfft=n_fft
18
+ )
19
+ return frequency, density
evaluation/features/signal_to_noise_ratio.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numpy as np
7
+ import scipy.signal as sig
8
+ import copy
9
+ import librosa
10
+
11
+
12
+ def bandpower(ps, mode="time"):
13
+ """
14
+ estimate bandpower, see https://de.mathworks.com/help/signal/ref/bandpower.html
15
+ """
16
+ if mode == "time":
17
+ x = ps
18
+ l2norm = np.linalg.norm(x) ** 2.0 / len(x)
19
+ return l2norm
20
+ elif mode == "psd":
21
+ return sum(ps)
22
+
23
+
24
+ def getIndizesAroundPeak(arr, peakIndex, searchWidth=1000):
25
+ peakBins = []
26
+ magMax = arr[peakIndex]
27
+ curVal = magMax
28
+ for i in range(searchWidth):
29
+ newBin = peakIndex + i
30
+ if newBin >= len(arr):
31
+ break
32
+ newVal = arr[newBin]
33
+ if newVal > curVal:
34
+ break
35
+ else:
36
+ peakBins.append(int(newBin))
37
+ curVal = newVal
38
+ curVal = magMax
39
+ for i in range(searchWidth):
40
+ newBin = peakIndex - i
41
+ if newBin < 0:
42
+ break
43
+ newVal = arr[newBin]
44
+ if newVal > curVal:
45
+ break
46
+ else:
47
+ peakBins.append(int(newBin))
48
+ curVal = newVal
49
+ return np.array(list(set(peakBins)))
50
+
51
+
52
+ def freqToBin(fAxis, Freq):
53
+ return np.argmin(abs(fAxis - Freq))
54
+
55
+
56
+ def getPeakInArea(psd, faxis, estimation, searchWidthHz=10):
57
+ """
58
+ returns bin and frequency of the maximum in an area
59
+ """
60
+ binLow = freqToBin(faxis, estimation - searchWidthHz)
61
+ binHi = freqToBin(faxis, estimation + searchWidthHz)
62
+ peakbin = binLow + np.argmax(psd[binLow : binHi + 1])
63
+ return peakbin, faxis[peakbin]
64
+
65
+
66
+ def getHarmonics(fund, sr, nHarmonics=6, aliased=False):
67
+ harmonicMultipliers = np.arange(2, nHarmonics + 2)
68
+ harmonicFs = fund * harmonicMultipliers
69
+ if not aliased:
70
+ harmonicFs[harmonicFs > sr / 2] = -1
71
+ harmonicFs = np.delete(harmonicFs, harmonicFs == -1)
72
+ else:
73
+ nyqZone = np.floor(harmonicFs / (sr / 2))
74
+ oddEvenNyq = nyqZone % 2
75
+ harmonicFs = np.mod(harmonicFs, sr / 2)
76
+ harmonicFs[oddEvenNyq == 1] = (sr / 2) - harmonicFs[oddEvenNyq == 1]
77
+ return harmonicFs
78
+
79
+
80
+ def extract_snr(audio, sr=None):
81
+ """Extract Signal-to-Noise Ratio for a given audio."""
82
+ if sr != None:
83
+ audio, _ = librosa.load(audio, sr=sr)
84
+ else:
85
+ audio, sr = librosa.load(audio, sr=sr)
86
+ faxis, ps = sig.periodogram(
87
+ audio, fs=sr, window=("kaiser", 38)
88
+ ) # get periodogram, parametrized like in matlab
89
+ fundBin = np.argmax(
90
+ ps
91
+ ) # estimate fundamental at maximum amplitude, get the bin number
92
+ fundIndizes = getIndizesAroundPeak(
93
+ ps, fundBin
94
+ ) # get bin numbers around fundamental peak
95
+ fundFrequency = faxis[fundBin] # frequency of fundamental
96
+
97
+ nHarmonics = 18
98
+ harmonicFs = getHarmonics(
99
+ fundFrequency, sr, nHarmonics=nHarmonics, aliased=True
100
+ ) # get harmonic frequencies
101
+
102
+ harmonicBorders = np.zeros([2, nHarmonics], dtype=np.int16).T
103
+ fullHarmonicBins = np.array([], dtype=np.int16)
104
+ fullHarmonicBinList = []
105
+ harmPeakFreqs = []
106
+ harmPeaks = []
107
+ for i, harmonic in enumerate(harmonicFs):
108
+ searcharea = 0.1 * fundFrequency
109
+ estimation = harmonic
110
+
111
+ binNum, freq = getPeakInArea(ps, faxis, estimation, searcharea)
112
+ harmPeakFreqs.append(freq)
113
+ harmPeaks.append(ps[binNum])
114
+ allBins = getIndizesAroundPeak(ps, binNum, searchWidth=1000)
115
+ fullHarmonicBins = np.append(fullHarmonicBins, allBins)
116
+ fullHarmonicBinList.append(allBins)
117
+ harmonicBorders[i, :] = [allBins[0], allBins[-1]]
118
+
119
+ fundIndizes.sort()
120
+ pFund = bandpower(ps[fundIndizes[0] : fundIndizes[-1]]) # get power of fundamental
121
+
122
+ noisePrepared = copy.copy(ps)
123
+ noisePrepared[fundIndizes] = 0
124
+ noisePrepared[fullHarmonicBins] = 0
125
+ noiseMean = np.median(noisePrepared[noisePrepared != 0])
126
+ noisePrepared[fundIndizes] = noiseMean
127
+ noisePrepared[fullHarmonicBins] = noiseMean
128
+
129
+ noisePower = bandpower(noisePrepared)
130
+
131
+ r = 10 * np.log10(pFund / noisePower)
132
+
133
+ return r, 10 * np.log10(noisePower)
evaluation/features/singing_power_ratio.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import librosa
8
+
9
+ from utils.util import JsonHParams
10
+ from utils.f0 import get_f0_features_using_parselmouth, get_pitch_sub_median
11
+ from utils.mel import extract_mel_features
12
+
13
+
14
+ def extract_spr(
15
+ audio,
16
+ fs=None,
17
+ hop_length=256,
18
+ win_length=1024,
19
+ n_fft=1024,
20
+ n_mels=128,
21
+ f0_min=37,
22
+ f0_max=1000,
23
+ pitch_bin=256,
24
+ pitch_max=1100.0,
25
+ pitch_min=50.0,
26
+ ):
27
+ """Compute Singing Power Ratio (SPR) from a given audio.
28
+ audio: path to the audio.
29
+ fs: sampling rate.
30
+ hop_length: hop length.
31
+ win_length: window length.
32
+ n_mels: number of mel filters.
33
+ f0_min: lower limit for f0.
34
+ f0_max: upper limit for f0.
35
+ pitch_bin: number of bins for f0 quantization.
36
+ pitch_max: upper limit for f0 quantization.
37
+ pitch_min: lower limit for f0 quantization.
38
+ """
39
+ # Load audio
40
+ if fs != None:
41
+ audio, _ = librosa.load(audio, sr=fs)
42
+ else:
43
+ audio, fs = librosa.load(audio)
44
+ audio = torch.from_numpy(audio)
45
+
46
+ # Initialize config
47
+ cfg = JsonHParams()
48
+ cfg.sample_rate = fs
49
+ cfg.hop_size = hop_length
50
+ cfg.win_size = win_length
51
+ cfg.n_fft = n_fft
52
+ cfg.n_mel = n_mels
53
+ cfg.f0_min = f0_min
54
+ cfg.f0_max = f0_max
55
+ cfg.pitch_bin = pitch_bin
56
+ cfg.pitch_max = pitch_max
57
+ cfg.pitch_min = pitch_min
58
+
59
+ # Extract mel spectrograms
60
+
61
+ cfg.fmin = 2000
62
+ cfg.fmax = 4000
63
+
64
+ mel1 = extract_mel_features(
65
+ y=audio.unsqueeze(0),
66
+ cfg=cfg,
67
+ ).squeeze(0)
68
+
69
+ cfg.fmin = 0
70
+ cfg.fmax = 2000
71
+
72
+ mel2 = extract_mel_features(
73
+ y=audio.unsqueeze(0),
74
+ cfg=cfg,
75
+ ).squeeze(0)
76
+
77
+ f0 = get_f0_features_using_parselmouth(
78
+ audio,
79
+ cfg,
80
+ )[0]
81
+
82
+ # Mel length alignment
83
+ length = min(len(f0), mel1.shape[-1])
84
+ f0 = f0[:length]
85
+ mel1 = mel1[:, :length]
86
+ mel2 = mel2[:, :length]
87
+
88
+ # Compute SPR
89
+ res = []
90
+
91
+ for i in range(mel1.shape[-1]):
92
+ if f0[i] <= 1:
93
+ continue
94
+
95
+ chunk1 = mel1[:, i]
96
+ chunk2 = mel2[:, i]
97
+
98
+ max1 = max(chunk1.numpy())
99
+ max2 = max(chunk2.numpy())
100
+
101
+ tmp_res = max2 - max1
102
+
103
+ res.append(tmp_res)
104
+
105
+ if len(res) == 0:
106
+ return False
107
+ else:
108
+ return sum(res) / len(res)
evaluation/metrics/__init__.py ADDED
File without changes
evaluation/metrics/energy/__init__.py ADDED
File without changes
evaluation/metrics/energy/energy_pearson_coefficients.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import librosa
8
+ import torch
9
+
10
+ import numpy as np
11
+ from numpy import linalg as LA
12
+
13
+ from torchmetrics import PearsonCorrCoef
14
+
15
+
16
+ def extract_energy_pearson_coeffcients(
17
+ audio_ref,
18
+ audio_deg,
19
+ fs=None,
20
+ n_fft=1024,
21
+ hop_length=256,
22
+ win_length=1024,
23
+ method="dtw",
24
+ db_scale=True,
25
+ ):
26
+ """Compute Energy Pearson Coefficients between the predicted and the ground truth audio.
27
+ audio_ref: path to the ground truth audio.
28
+ audio_deg: path to the predicted audio.
29
+ fs: sampling rate.
30
+ n_fft: fft size.
31
+ hop_length: hop length.
32
+ win_length: window length.
33
+ method: "dtw" will use dtw algorithm to align the length of the ground truth and predicted audio.
34
+ "cut" will cut both audios into a same length according to the one with the shorter length.
35
+ db_scale: the ground truth and predicted audio will be converted to db_scale if "True".
36
+ """
37
+ # Initialize method
38
+ pearson = PearsonCorrCoef()
39
+
40
+ # Load audio
41
+ if fs != None:
42
+ audio_ref, _ = librosa.load(audio_ref, sr=fs)
43
+ audio_deg, _ = librosa.load(audio_deg, sr=fs)
44
+ else:
45
+ audio_ref, fs = librosa.load(audio_ref)
46
+ audio_deg, fs = librosa.load(audio_deg)
47
+
48
+ # STFT
49
+ spec_ref = librosa.stft(
50
+ y=audio_ref, n_fft=n_fft, hop_length=hop_length, win_length=win_length
51
+ )
52
+ spec_deg = librosa.stft(
53
+ y=audio_deg, n_fft=n_fft, hop_length=hop_length, win_length=win_length
54
+ )
55
+
56
+ # Get magnitudes
57
+ mag_ref = np.abs(spec_ref).T
58
+ mag_deg = np.abs(spec_deg).T
59
+
60
+ # Convert spectrogram to energy
61
+ energy_ref = LA.norm(mag_ref, axis=1)
62
+ energy_deg = LA.norm(mag_deg, axis=1)
63
+
64
+ # Convert to db_scale
65
+ if db_scale:
66
+ energy_ref = 20 * np.log10(energy_ref)
67
+ energy_deg = 20 * np.log10(energy_deg)
68
+
69
+ # Audio length alignment
70
+ if method == "cut":
71
+ length = min(len(energy_ref), len(energy_deg))
72
+ energy_ref = energy_ref[:length]
73
+ energy_deg = energy_deg[:length]
74
+ elif method == "dtw":
75
+ _, wp = librosa.sequence.dtw(energy_ref, energy_deg, backtrack=True)
76
+ energy_gt_new = []
77
+ energy_pred_new = []
78
+ for i in range(wp.shape[0]):
79
+ gt_index = wp[i][0]
80
+ pred_index = wp[i][1]
81
+ energy_gt_new.append(energy_ref[gt_index])
82
+ energy_pred_new.append(energy_deg[pred_index])
83
+ energy_ref = np.array(energy_gt_new)
84
+ energy_deg = np.array(energy_pred_new)
85
+ assert len(energy_ref) == len(energy_deg)
86
+
87
+ # Convert to tensor
88
+ energy_ref = torch.from_numpy(energy_ref)
89
+ energy_deg = torch.from_numpy(energy_deg)
90
+
91
+ return pearson(energy_ref, energy_deg).numpy().tolist()
evaluation/metrics/energy/energy_rmse.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import librosa
8
+ import torch
9
+
10
+ import numpy as np
11
+ from numpy import linalg as LA
12
+
13
+
14
+ def extract_energy_rmse(
15
+ audio_ref,
16
+ audio_deg,
17
+ fs=None,
18
+ n_fft=1024,
19
+ hop_length=256,
20
+ win_length=1024,
21
+ method="dtw",
22
+ db_scale=True,
23
+ ):
24
+ """Compute Energy Root Mean Square Error (RMSE) between the predicted and the ground truth audio.
25
+ audio_ref: path to the ground truth audio.
26
+ audio_deg: path to the predicted audio.
27
+ fs: sampling rate.
28
+ n_fft: fft size.
29
+ hop_length: hop length.
30
+ win_length: window length.
31
+ method: "dtw" will use dtw algorithm to align the length of the ground truth and predicted audio.
32
+ "cut" will cut both audios into a same length according to the one with the shorter length.
33
+ db_scale: the ground truth and predicted audio will be converted to db_scale if "True".
34
+ """
35
+ # Load audio
36
+ if fs != None:
37
+ audio_ref, _ = librosa.load(audio_ref, sr=fs)
38
+ audio_deg, _ = librosa.load(audio_deg, sr=fs)
39
+ else:
40
+ audio_ref, fs = librosa.load(audio_ref)
41
+ audio_deg, fs = librosa.load(audio_deg)
42
+
43
+ # STFT
44
+ spec_ref = librosa.stft(
45
+ y=audio_ref, n_fft=n_fft, hop_length=hop_length, win_length=win_length
46
+ )
47
+ spec_deg = librosa.stft(
48
+ y=audio_deg, n_fft=n_fft, hop_length=hop_length, win_length=win_length
49
+ )
50
+
51
+ # Get magnitudes
52
+ mag_ref = np.abs(spec_ref).T
53
+ mag_deg = np.abs(spec_deg).T
54
+
55
+ # Convert spectrogram to energy
56
+ energy_ref = LA.norm(mag_ref, axis=1)
57
+ energy_deg = LA.norm(mag_deg, axis=1)
58
+
59
+ # Convert to db_scale
60
+ if db_scale:
61
+ energy_ref = 20 * np.log10(energy_ref)
62
+ energy_deg = 20 * np.log10(energy_deg)
63
+
64
+ # Audio length alignment
65
+ if method == "cut":
66
+ length = min(len(energy_ref), len(energy_deg))
67
+ energy_ref = energy_ref[:length]
68
+ energy_deg = energy_deg[:length]
69
+ elif method == "dtw":
70
+ _, wp = librosa.sequence.dtw(energy_ref, energy_deg, backtrack=True)
71
+ energy_gt_new = []
72
+ energy_pred_new = []
73
+ for i in range(wp.shape[0]):
74
+ gt_index = wp[i][0]
75
+ pred_index = wp[i][1]
76
+ energy_gt_new.append(energy_ref[gt_index])
77
+ energy_pred_new.append(energy_deg[pred_index])
78
+ energy_ref = np.array(energy_gt_new)
79
+ energy_deg = np.array(energy_pred_new)
80
+ assert len(energy_ref) == len(energy_deg)
81
+
82
+ # Compute RMSE
83
+ energy_mse = np.square(np.subtract(energy_ref, energy_deg)).mean()
84
+ energy_rmse = math.sqrt(energy_mse)
85
+
86
+ return energy_rmse
evaluation/metrics/f0/__init__.py ADDED
File without changes
evaluation/metrics/f0/f0_pearson_coefficients.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import librosa
8
+
9
+ import numpy as np
10
+
11
+ from torchmetrics import PearsonCorrCoef
12
+
13
+ from utils.util import JsonHParams
14
+ from utils.f0 import get_f0_features_using_parselmouth, get_pitch_sub_median
15
+
16
+
17
+ def extract_fpc(
18
+ audio_ref,
19
+ audio_deg,
20
+ fs=None,
21
+ hop_length=256,
22
+ f0_min=50,
23
+ f0_max=1100,
24
+ pitch_bin=256,
25
+ pitch_min=50,
26
+ pitch_max=1100,
27
+ need_mean=True,
28
+ method="dtw",
29
+ ):
30
+ """Compute F0 Pearson Distance (FPC) between the predicted and the ground truth audio.
31
+ audio_ref: path to the ground truth audio.
32
+ audio_deg: path to the predicted audio.
33
+ fs: sampling rate.
34
+ hop_length: hop length.
35
+ f0_min: lower limit for f0.
36
+ f0_max: upper limit for f0.
37
+ pitch_bin: number of bins for f0 quantization.
38
+ pitch_max: upper limit for f0 quantization.
39
+ pitch_min: lower limit for f0 quantization.
40
+ need_mean: subtract the mean value from f0 if "True".
41
+ method: "dtw" will use dtw algorithm to align the length of the ground truth and predicted audio.
42
+ "cut" will cut both audios into a same length according to the one with the shorter length.
43
+ """
44
+ # Initialize method
45
+ pearson = PearsonCorrCoef()
46
+
47
+ # Load audio
48
+ if fs != None:
49
+ audio_ref, _ = librosa.load(audio_ref, sr=fs)
50
+ audio_deg, _ = librosa.load(audio_deg, sr=fs)
51
+ else:
52
+ audio_ref, fs = librosa.load(audio_ref)
53
+ audio_deg, fs = librosa.load(audio_deg)
54
+
55
+ # Initialize config
56
+ cfg = JsonHParams()
57
+ cfg.sample_rate = fs
58
+ cfg.hop_size = hop_length
59
+ cfg.f0_min = f0_min
60
+ cfg.f0_max = f0_max
61
+ cfg.pitch_bin = pitch_bin
62
+ cfg.pitch_max = pitch_max
63
+ cfg.pitch_min = pitch_min
64
+
65
+ # Compute f0
66
+ f0_ref = get_f0_features_using_parselmouth(
67
+ audio_ref,
68
+ cfg,
69
+ )[0]
70
+
71
+ f0_deg = get_f0_features_using_parselmouth(
72
+ audio_deg,
73
+ cfg,
74
+ )[0]
75
+
76
+ # Subtract mean value from f0
77
+ if need_mean:
78
+ f0_ref = torch.from_numpy(f0_ref)
79
+ f0_deg = torch.from_numpy(f0_deg)
80
+
81
+ f0_ref = get_pitch_sub_median(f0_ref).numpy()
82
+ f0_deg = get_pitch_sub_median(f0_deg).numpy()
83
+
84
+ # Avoid silence
85
+ min_length = min(len(f0_ref), len(f0_deg))
86
+ if min_length <= 1:
87
+ return 1
88
+
89
+ # F0 length alignment
90
+ if method == "cut":
91
+ length = min(len(f0_ref), len(f0_deg))
92
+ f0_ref = f0_ref[:length]
93
+ f0_deg = f0_deg[:length]
94
+ elif method == "dtw":
95
+ _, wp = librosa.sequence.dtw(f0_ref, f0_deg, backtrack=True)
96
+ f0_gt_new = []
97
+ f0_pred_new = []
98
+ for i in range(wp.shape[0]):
99
+ gt_index = wp[i][0]
100
+ pred_index = wp[i][1]
101
+ f0_gt_new.append(f0_ref[gt_index])
102
+ f0_pred_new.append(f0_deg[pred_index])
103
+ f0_ref = np.array(f0_gt_new)
104
+ f0_deg = np.array(f0_pred_new)
105
+ assert len(f0_ref) == len(f0_deg)
106
+
107
+ # Convert to tensor
108
+ f0_ref = torch.from_numpy(f0_ref)
109
+ f0_deg = torch.from_numpy(f0_deg)
110
+
111
+ return pearson(f0_ref, f0_deg).numpy().tolist()
evaluation/metrics/f0/f0_periodicity_rmse.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torchcrepe
7
+ import math
8
+ import librosa
9
+ import torch
10
+
11
+ import numpy as np
12
+
13
+
14
+ def extract_f0_periodicity_rmse(
15
+ audio_ref,
16
+ audio_deg,
17
+ fs=None,
18
+ hop_length=256,
19
+ method="dtw",
20
+ ):
21
+ """Compute f0 periodicity Root Mean Square Error (RMSE) between the predicted and the ground truth audio.
22
+ audio_ref: path to the ground truth audio.
23
+ audio_deg: path to the predicted audio.
24
+ fs: sampling rate.
25
+ hop_length: hop length.
26
+ method: "dtw" will use dtw algorithm to align the length of the ground truth and predicted audio.
27
+ "cut" will cut both audios into a same length according to the one with the shorter length.
28
+ """
29
+ # Load audio
30
+ if fs != None:
31
+ audio_ref, _ = librosa.load(audio_ref, sr=fs)
32
+ audio_deg, _ = librosa.load(audio_deg, sr=fs)
33
+ else:
34
+ audio_ref, fs = librosa.load(audio_ref)
35
+ audio_deg, fs = librosa.load(audio_deg)
36
+
37
+ # Convert to torch
38
+ audio_ref = torch.from_numpy(audio_ref).unsqueeze(0)
39
+ audio_deg = torch.from_numpy(audio_deg).unsqueeze(0)
40
+
41
+ # Get periodicity
42
+ pitch_ref, periodicity_ref = torchcrepe.predict(
43
+ audio_ref,
44
+ sample_rate=fs,
45
+ hop_length=hop_length,
46
+ fmin=0,
47
+ fmax=1500,
48
+ model="full",
49
+ return_periodicity=True,
50
+ device="cuda:0",
51
+ )
52
+ pitch_deg, periodicity_deg = torchcrepe.predict(
53
+ audio_deg,
54
+ sample_rate=fs,
55
+ hop_length=hop_length,
56
+ fmin=0,
57
+ fmax=1500,
58
+ model="full",
59
+ return_periodicity=True,
60
+ device="cuda:0",
61
+ )
62
+
63
+ # Cut silence
64
+ periodicity_ref = (
65
+ torchcrepe.threshold.Silence()(
66
+ periodicity_ref,
67
+ audio_ref,
68
+ fs,
69
+ hop_length=hop_length,
70
+ )
71
+ .squeeze(0)
72
+ .numpy()
73
+ )
74
+ periodicity_deg = (
75
+ torchcrepe.threshold.Silence()(
76
+ periodicity_deg,
77
+ audio_deg,
78
+ fs,
79
+ hop_length=hop_length,
80
+ )
81
+ .squeeze(0)
82
+ .numpy()
83
+ )
84
+
85
+ # Avoid silence audio
86
+ min_length = min(len(periodicity_ref), len(periodicity_deg))
87
+ if min_length <= 1:
88
+ return 0
89
+
90
+ # Periodicity length alignment
91
+ if method == "cut":
92
+ length = min(len(periodicity_ref), len(periodicity_deg))
93
+ periodicity_ref = periodicity_ref[:length]
94
+ periodicity_deg = periodicity_deg[:length]
95
+ elif method == "dtw":
96
+ _, wp = librosa.sequence.dtw(periodicity_ref, periodicity_deg, backtrack=True)
97
+ periodicity_ref_new = []
98
+ periodicity_deg_new = []
99
+ for i in range(wp.shape[0]):
100
+ ref_index = wp[i][0]
101
+ deg_index = wp[i][1]
102
+ periodicity_ref_new.append(periodicity_ref[ref_index])
103
+ periodicity_deg_new.append(periodicity_deg[deg_index])
104
+ periodicity_ref = np.array(periodicity_ref_new)
105
+ periodicity_deg = np.array(periodicity_deg_new)
106
+ assert len(periodicity_ref) == len(periodicity_deg)
107
+
108
+ # Compute RMSE
109
+ periodicity_mse = np.square(np.subtract(periodicity_ref, periodicity_deg)).mean()
110
+ periodicity_rmse = math.sqrt(periodicity_mse)
111
+
112
+ return periodicity_rmse
evaluation/metrics/f0/f0_rmse.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import librosa
8
+ import torch
9
+
10
+ import numpy as np
11
+
12
+ from utils.util import JsonHParams
13
+ from utils.f0 import get_f0_features_using_parselmouth, get_pitch_sub_median
14
+
15
+
16
+ ZERO = 1e-8
17
+
18
+
19
+ def extract_f0rmse(
20
+ audio_ref,
21
+ audio_deg,
22
+ fs=None,
23
+ hop_length=256,
24
+ f0_min=37,
25
+ f0_max=1000,
26
+ pitch_bin=256,
27
+ pitch_max=1100.0,
28
+ pitch_min=50.0,
29
+ need_mean=True,
30
+ method="dtw",
31
+ ):
32
+ """Compute F0 Root Mean Square Error (RMSE) between the predicted and the ground truth audio.
33
+ audio_ref: path to the ground truth audio.
34
+ audio_deg: path to the predicted audio.
35
+ fs: sampling rate.
36
+ hop_length: hop length.
37
+ f0_min: lower limit for f0.
38
+ f0_max: upper limit for f0.
39
+ pitch_bin: number of bins for f0 quantization.
40
+ pitch_max: upper limit for f0 quantization.
41
+ pitch_min: lower limit for f0 quantization.
42
+ need_mean: subtract the mean value from f0 if "True".
43
+ method: "dtw" will use dtw algorithm to align the length of the ground truth and predicted audio.
44
+ "cut" will cut both audios into a same length according to the one with the shorter length.
45
+ """
46
+ # Load audio
47
+ if fs != None:
48
+ audio_ref, _ = librosa.load(audio_ref, sr=fs)
49
+ audio_deg, _ = librosa.load(audio_deg, sr=fs)
50
+ else:
51
+ audio_ref, fs = librosa.load(audio_ref)
52
+ audio_deg, fs = librosa.load(audio_deg)
53
+
54
+ # Initialize config for f0 extraction
55
+ cfg = JsonHParams()
56
+ cfg.sample_rate = fs
57
+ cfg.hop_size = hop_length
58
+ cfg.f0_min = f0_min
59
+ cfg.f0_max = f0_max
60
+ cfg.pitch_bin = pitch_bin
61
+ cfg.pitch_max = pitch_max
62
+ cfg.pitch_min = pitch_min
63
+
64
+ # Extract f0
65
+ f0_ref = get_f0_features_using_parselmouth(
66
+ audio_ref,
67
+ cfg,
68
+ )[0]
69
+
70
+ f0_deg = get_f0_features_using_parselmouth(
71
+ audio_deg,
72
+ cfg,
73
+ )[0]
74
+
75
+ # Subtract mean value from f0
76
+ if need_mean:
77
+ f0_ref = torch.from_numpy(f0_ref)
78
+ f0_deg = torch.from_numpy(f0_deg)
79
+
80
+ f0_ref = get_pitch_sub_median(f0_ref).numpy()
81
+ f0_deg = get_pitch_sub_median(f0_deg).numpy()
82
+
83
+ # Avoid silence
84
+ min_length = min(len(f0_ref), len(f0_deg))
85
+ if min_length <= 1:
86
+ return 0
87
+
88
+ # F0 length alignment
89
+ if method == "cut":
90
+ length = min(len(f0_ref), len(f0_deg))
91
+ f0_ref = f0_ref[:length]
92
+ f0_deg = f0_deg[:length]
93
+ elif method == "dtw":
94
+ _, wp = librosa.sequence.dtw(f0_ref, f0_deg, backtrack=True)
95
+ f0_gt_new = []
96
+ f0_pred_new = []
97
+ for i in range(wp.shape[0]):
98
+ gt_index = wp[i][0]
99
+ pred_index = wp[i][1]
100
+ f0_gt_new.append(f0_ref[gt_index])
101
+ f0_pred_new.append(f0_deg[pred_index])
102
+ f0_ref = np.array(f0_gt_new)
103
+ f0_deg = np.array(f0_pred_new)
104
+ assert len(f0_ref) == len(f0_deg)
105
+
106
+ # Compute RMSE
107
+ f0_mse = np.square(np.subtract(f0_ref, f0_deg)).mean()
108
+ f0_rmse = math.sqrt(f0_mse)
109
+
110
+ return f0_rmse
evaluation/metrics/f0/v_uv_f1.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import librosa
8
+ import torch
9
+
10
+ import numpy as np
11
+
12
+ from utils.util import JsonHParams
13
+ from utils.f0 import get_f0_features_using_parselmouth
14
+
15
+
16
+ ZERO = 1e-8
17
+
18
+
19
+ def extract_f1_v_uv(
20
+ audio_ref,
21
+ audio_deg,
22
+ fs=None,
23
+ hop_length=256,
24
+ f0_min=37,
25
+ f0_max=1000,
26
+ pitch_bin=256,
27
+ pitch_max=1100.0,
28
+ pitch_min=50.0,
29
+ method="dtw",
30
+ ):
31
+ """Compute F1 socre of voiced/unvoiced accuracy between the predicted and the ground truth audio.
32
+ audio_ref: path to the ground truth audio.
33
+ audio_deg: path to the predicted audio.
34
+ fs: sampling rate.
35
+ hop_length: hop length.
36
+ f0_min: lower limit for f0.
37
+ f0_max: upper limit for f0.
38
+ pitch_bin: number of bins for f0 quantization.
39
+ pitch_max: upper limit for f0 quantization.
40
+ pitch_min: lower limit for f0 quantization.
41
+ need_mean: subtract the mean value from f0 if "True".
42
+ method: "dtw" will use dtw algorithm to align the length of the ground truth and predicted audio.
43
+ "cut" will cut both audios into a same length according to the one with the shorter length.
44
+ """
45
+ # Load audio
46
+ if fs != None:
47
+ audio_ref, _ = librosa.load(audio_ref, sr=fs)
48
+ audio_deg, _ = librosa.load(audio_deg, sr=fs)
49
+ else:
50
+ audio_ref, fs = librosa.load(audio_ref)
51
+ audio_deg, fs = librosa.load(audio_deg)
52
+
53
+ # Initialize config
54
+ cfg = JsonHParams()
55
+ cfg.sample_rate = fs
56
+ cfg.hop_size = hop_length
57
+ cfg.f0_min = f0_min
58
+ cfg.f0_max = f0_max
59
+ cfg.pitch_bin = pitch_bin
60
+ cfg.pitch_max = pitch_max
61
+ cfg.pitch_min = pitch_min
62
+
63
+ # Compute f0
64
+ f0_ref = get_f0_features_using_parselmouth(
65
+ audio_ref,
66
+ cfg,
67
+ )[0]
68
+
69
+ f0_deg = get_f0_features_using_parselmouth(
70
+ audio_deg,
71
+ cfg,
72
+ )[0]
73
+
74
+ # Avoid silence
75
+ min_length = min(len(f0_ref), len(f0_deg))
76
+ if min_length <= 1:
77
+ return 0, 0, 0
78
+
79
+ # F0 length alignment
80
+ if method == "cut":
81
+ length = min(len(f0_ref), len(f0_deg))
82
+ f0_ref = f0_ref[:length]
83
+ f0_deg = f0_deg[:length]
84
+ elif method == "dtw":
85
+ _, wp = librosa.sequence.dtw(f0_ref, f0_deg, backtrack=True)
86
+ f0_gt_new = []
87
+ f0_pred_new = []
88
+ for i in range(wp.shape[0]):
89
+ gt_index = wp[i][0]
90
+ pred_index = wp[i][1]
91
+ f0_gt_new.append(f0_ref[gt_index])
92
+ f0_pred_new.append(f0_deg[pred_index])
93
+ f0_ref = np.array(f0_gt_new)
94
+ f0_deg = np.array(f0_pred_new)
95
+ assert len(f0_ref) == len(f0_deg)
96
+
97
+ # Get voiced/unvoiced parts
98
+ ref_voiced = torch.Tensor([f0_ref != 0]).bool()
99
+ deg_voiced = torch.Tensor([f0_deg != 0]).bool()
100
+
101
+ # Compute TP, FP, FN
102
+ true_postives = (ref_voiced & deg_voiced).sum()
103
+ false_postives = (~ref_voiced & deg_voiced).sum()
104
+ false_negatives = (ref_voiced & ~deg_voiced).sum()
105
+
106
+ return (
107
+ true_postives.numpy().tolist(),
108
+ false_postives.numpy().tolist(),
109
+ false_negatives.numpy().tolist(),
110
+ )
evaluation/metrics/intelligibility/__init__.py ADDED
File without changes
evaluation/metrics/intelligibility/character_error_rate.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import whisper
7
+
8
+ from torchmetrics import CharErrorRate
9
+
10
+
11
+ def extract_cer(
12
+ content_gt=None,
13
+ audio_ref=None,
14
+ audio_deg=None,
15
+ fs=None,
16
+ language="chinese",
17
+ remove_space=True,
18
+ remove_punctuation=True,
19
+ mode="gt_audio",
20
+ ):
21
+ """Compute Character Error Rate (CER) between the predicted and the ground truth audio.
22
+ content_gt: the ground truth content.
23
+ audio_ref: path to the ground truth audio.
24
+ audio_deg: path to the predicted audio.
25
+ mode: "gt_content" computes the CER between the predicted content obtained from the whisper model and the ground truth content.
26
+ both content_gt and audio_deg are needed.
27
+ "gt_audio" computes the CER between the extracted ground truth and predicted contents obtained from the whisper model.
28
+ both audio_ref and audio_deg are needed.
29
+ """
30
+ # Get ground truth content
31
+ if mode == "gt_content":
32
+ assert content_gt != None
33
+ if language == "chinese":
34
+ prompt = "以下是普通话的句子"
35
+ model = whisper.load_model("large").cuda()
36
+ result_deg = model.transcribe(
37
+ audio_deg, language="zh", verbose=True, initial_prompt=prompt
38
+ )
39
+ elif language == "english":
40
+ model = whisper.load_model("large").cuda()
41
+ result_deg = model.transcribe(audio_deg, language="en", verbose=True)
42
+ elif mode == "gt_audio":
43
+ assert audio_ref != None
44
+ if language == "chinese":
45
+ prompt = "以下是普通话的句子"
46
+ model = whisper.load_model("large").cuda()
47
+ result_ref = model.transcribe(
48
+ audio_ref, language="zh", verbose=True, initial_prompt=prompt
49
+ )
50
+ result_deg = model.transcribe(
51
+ audio_deg, language="zh", verbose=True, initial_prompt=prompt
52
+ )
53
+ elif language == "english":
54
+ model = whisper.load_model("large").cuda()
55
+ result_ref = model.transcribe(audio_deg, language="en", verbose=True)
56
+ result_deg = model.transcribe(audio_deg, language="en", verbose=True)
57
+ content_gt = result_ref["text"]
58
+ if remove_space:
59
+ content_gt = content_gt.replace(" ", "")
60
+ if remove_punctuation:
61
+ content_gt = content_gt.replace(".", "")
62
+ content_gt = content_gt.replace("'", "")
63
+ content_gt = content_gt.replace("-", "")
64
+ content_gt = content_gt.replace(",", "")
65
+ content_gt = content_gt.replace("!", "")
66
+ content_gt = content_gt.lower()
67
+
68
+ # Get predicted truth content
69
+ content_pred = result_deg["text"]
70
+ if remove_space:
71
+ content_pred = content_pred.replace(" ", "")
72
+ if remove_punctuation:
73
+ content_pred = content_pred.replace(".", "")
74
+ content_pred = content_pred.replace("'", "")
75
+ content_pred = content_pred.replace("-", "")
76
+ content_pred = content_pred.replace(",", "")
77
+ content_pred = content_pred.replace("!", "")
78
+ content_pred = content_pred.lower()
79
+ cer = CharErrorRate()
80
+
81
+ return cer(content_pred, content_gt).numpy().tolist()
evaluation/metrics/intelligibility/word_error_rate.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import whisper
7
+
8
+ from torchmetrics import WordErrorRate
9
+
10
+
11
+ def extract_wer(
12
+ content_gt=None,
13
+ audio_ref=None,
14
+ audio_deg=None,
15
+ fs=None,
16
+ language="chinese",
17
+ remove_space=True,
18
+ remove_punctuation=True,
19
+ mode="gt_audio",
20
+ ):
21
+ """Compute Word Error Rate (WER) between the predicted and the ground truth audio.
22
+ content_gt: the ground truth content.
23
+ audio_ref: path to the ground truth audio.
24
+ audio_deg: path to the predicted audio.
25
+ mode: "gt_content" computes the WER between the predicted content obtained from the whisper model and the ground truth content.
26
+ both content_gt and audio_deg are needed.
27
+ "gt_audio" computes the WER between the extracted ground truth and predicted contents obtained from the whisper model.
28
+ both audio_ref and audio_deg are needed.
29
+ """
30
+ # Get ground truth content
31
+ if mode == "gt_content":
32
+ assert content_gt != None
33
+ if language == "chinese":
34
+ prompt = "以下是普通话的句子"
35
+ model = whisper.load_model("large").cuda()
36
+ result_deg = model.transcribe(
37
+ audio_deg, language="zh", verbose=True, initial_prompt=prompt
38
+ )
39
+ elif language == "english":
40
+ model = whisper.load_model("large").cuda()
41
+ result_deg = model.transcribe(audio_deg, language="en", verbose=True)
42
+ elif mode == "gt_audio":
43
+ assert audio_ref != None
44
+ if language == "chinese":
45
+ prompt = "以下是普通话的句子"
46
+ model = whisper.load_model("large").cuda()
47
+ result_ref = model.transcribe(
48
+ audio_ref, language="zh", verbose=True, initial_prompt=prompt
49
+ )
50
+ result_deg = model.transcribe(
51
+ audio_deg, language="zh", verbose=True, initial_prompt=prompt
52
+ )
53
+ elif language == "english":
54
+ model = whisper.load_model("large").cuda()
55
+ result_ref = model.transcribe(audio_deg, language="en", verbose=True)
56
+ result_deg = model.transcribe(audio_deg, language="en", verbose=True)
57
+ content_gt = result_ref["text"]
58
+ if remove_space:
59
+ content_gt = content_gt.replace(" ", "")
60
+ if remove_punctuation:
61
+ content_gt = content_gt.replace(".", "")
62
+ content_gt = content_gt.replace("'", "")
63
+ content_gt = content_gt.replace("-", "")
64
+ content_gt = content_gt.replace(",", "")
65
+ content_gt = content_gt.replace("!", "")
66
+ content_gt = content_gt.lower()
67
+
68
+ # Get predicted content
69
+ content_pred = result_deg["text"]
70
+ if remove_space:
71
+ content_pred = content_pred.replace(" ", "")
72
+ if remove_punctuation:
73
+ content_pred = content_pred.replace(".", "")
74
+ content_pred = content_pred.replace("'", "")
75
+ content_pred = content_pred.replace("-", "")
76
+ content_pred = content_pred.replace(",", "")
77
+ content_pred = content_pred.replace("!", "")
78
+ content_pred = content_pred.lower()
79
+ wer = WordErrorRate()
80
+
81
+ return wer(content_pred, content_gt).numpy().tolist()
evaluation/metrics/similarity/__init__.py ADDED
File without changes
evaluation/metrics/similarity/models/RawNetBasicBlock.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+
13
+ class PreEmphasis(torch.nn.Module):
14
+ def __init__(self, coef: float = 0.97) -> None:
15
+ super().__init__()
16
+ self.coef = coef
17
+ # make kernel
18
+ # In pytorch, the convolution operation uses cross-correlation. So, filter is flipped.
19
+ self.register_buffer(
20
+ "flipped_filter",
21
+ torch.FloatTensor([-self.coef, 1.0]).unsqueeze(0).unsqueeze(0),
22
+ )
23
+
24
+ def forward(self, input: torch.tensor) -> torch.tensor:
25
+ assert (
26
+ len(input.size()) == 2
27
+ ), "The number of dimensions of input tensor must be 2!"
28
+ # reflect padding to match lengths of in/out
29
+ input = input.unsqueeze(1)
30
+ input = F.pad(input, (1, 0), "reflect")
31
+ return F.conv1d(input, self.flipped_filter)
32
+
33
+
34
+ class AFMS(nn.Module):
35
+ """
36
+ Alpha-Feature map scaling, added to the output of each residual block[1,2].
37
+
38
+ Reference:
39
+ [1] RawNet2 : https://www.isca-speech.org/archive/Interspeech_2020/pdfs/1011.pdf
40
+ [2] AMFS : https://www.koreascience.or.kr/article/JAKO202029757857763.page
41
+ """
42
+
43
+ def __init__(self, nb_dim: int) -> None:
44
+ super().__init__()
45
+ self.alpha = nn.Parameter(torch.ones((nb_dim, 1)))
46
+ self.fc = nn.Linear(nb_dim, nb_dim)
47
+ self.sig = nn.Sigmoid()
48
+
49
+ def forward(self, x):
50
+ y = F.adaptive_avg_pool1d(x, 1).view(x.size(0), -1)
51
+ y = self.sig(self.fc(y)).view(x.size(0), x.size(1), -1)
52
+
53
+ x = x + self.alpha
54
+ x = x * y
55
+ return x
56
+
57
+
58
+ class Bottle2neck(nn.Module):
59
+ def __init__(
60
+ self,
61
+ inplanes,
62
+ planes,
63
+ kernel_size=None,
64
+ dilation=None,
65
+ scale=4,
66
+ pool=False,
67
+ ):
68
+ super().__init__()
69
+
70
+ width = int(math.floor(planes / scale))
71
+
72
+ self.conv1 = nn.Conv1d(inplanes, width * scale, kernel_size=1)
73
+ self.bn1 = nn.BatchNorm1d(width * scale)
74
+
75
+ self.nums = scale - 1
76
+
77
+ convs = []
78
+ bns = []
79
+
80
+ num_pad = math.floor(kernel_size / 2) * dilation
81
+
82
+ for i in range(self.nums):
83
+ convs.append(
84
+ nn.Conv1d(
85
+ width,
86
+ width,
87
+ kernel_size=kernel_size,
88
+ dilation=dilation,
89
+ padding=num_pad,
90
+ )
91
+ )
92
+ bns.append(nn.BatchNorm1d(width))
93
+
94
+ self.convs = nn.ModuleList(convs)
95
+ self.bns = nn.ModuleList(bns)
96
+
97
+ self.conv3 = nn.Conv1d(width * scale, planes, kernel_size=1)
98
+ self.bn3 = nn.BatchNorm1d(planes)
99
+
100
+ self.relu = nn.ReLU()
101
+
102
+ self.width = width
103
+
104
+ self.mp = nn.MaxPool1d(pool) if pool else False
105
+ self.afms = AFMS(planes)
106
+
107
+ if inplanes != planes: # if change in number of filters
108
+ self.residual = nn.Sequential(
109
+ nn.Conv1d(inplanes, planes, kernel_size=1, stride=1, bias=False)
110
+ )
111
+ else:
112
+ self.residual = nn.Identity()
113
+
114
+ def forward(self, x):
115
+ residual = self.residual(x)
116
+
117
+ out = self.conv1(x)
118
+ out = self.relu(out)
119
+ out = self.bn1(out)
120
+
121
+ spx = torch.split(out, self.width, 1)
122
+ for i in range(self.nums):
123
+ if i == 0:
124
+ sp = spx[i]
125
+ else:
126
+ sp = sp + spx[i]
127
+ sp = self.convs[i](sp)
128
+ sp = self.relu(sp)
129
+ sp = self.bns[i](sp)
130
+ if i == 0:
131
+ out = sp
132
+ else:
133
+ out = torch.cat((out, sp), 1)
134
+
135
+ out = torch.cat((out, spx[self.nums]), 1)
136
+
137
+ out = self.conv3(out)
138
+ out = self.relu(out)
139
+ out = self.bn3(out)
140
+
141
+ out += residual
142
+ if self.mp:
143
+ out = self.mp(out)
144
+ out = self.afms(out)
145
+
146
+ return out
evaluation/metrics/similarity/models/RawNetModel.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ # -*- encoding: utf-8 -*-
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from asteroid_filterbanks import Encoder, ParamSincFB
11
+
12
+ from .RawNetBasicBlock import Bottle2neck, PreEmphasis
13
+
14
+
15
+ class RawNet3(nn.Module):
16
+ def __init__(self, block, model_scale, context, summed, C=1024, **kwargs):
17
+ super().__init__()
18
+
19
+ nOut = kwargs["nOut"]
20
+
21
+ self.context = context
22
+ self.encoder_type = kwargs["encoder_type"]
23
+ self.log_sinc = kwargs["log_sinc"]
24
+ self.norm_sinc = kwargs["norm_sinc"]
25
+ self.out_bn = kwargs["out_bn"]
26
+ self.summed = summed
27
+
28
+ self.preprocess = nn.Sequential(
29
+ PreEmphasis(), nn.InstanceNorm1d(1, eps=1e-4, affine=True)
30
+ )
31
+ self.conv1 = Encoder(
32
+ ParamSincFB(
33
+ C // 4,
34
+ 251,
35
+ stride=kwargs["sinc_stride"],
36
+ )
37
+ )
38
+ self.relu = nn.ReLU()
39
+ self.bn1 = nn.BatchNorm1d(C // 4)
40
+
41
+ self.layer1 = block(
42
+ C // 4, C, kernel_size=3, dilation=2, scale=model_scale, pool=5
43
+ )
44
+ self.layer2 = block(C, C, kernel_size=3, dilation=3, scale=model_scale, pool=3)
45
+ self.layer3 = block(C, C, kernel_size=3, dilation=4, scale=model_scale)
46
+ self.layer4 = nn.Conv1d(3 * C, 1536, kernel_size=1)
47
+
48
+ if self.context:
49
+ attn_input = 1536 * 3
50
+ else:
51
+ attn_input = 1536
52
+ print("self.encoder_type", self.encoder_type)
53
+ if self.encoder_type == "ECA":
54
+ attn_output = 1536
55
+ elif self.encoder_type == "ASP":
56
+ attn_output = 1
57
+ else:
58
+ raise ValueError("Undefined encoder")
59
+
60
+ self.attention = nn.Sequential(
61
+ nn.Conv1d(attn_input, 128, kernel_size=1),
62
+ nn.ReLU(),
63
+ nn.BatchNorm1d(128),
64
+ nn.Conv1d(128, attn_output, kernel_size=1),
65
+ nn.Softmax(dim=2),
66
+ )
67
+
68
+ self.bn5 = nn.BatchNorm1d(3072)
69
+
70
+ self.fc6 = nn.Linear(3072, nOut)
71
+ self.bn6 = nn.BatchNorm1d(nOut)
72
+
73
+ self.mp3 = nn.MaxPool1d(3)
74
+
75
+ def forward(self, x):
76
+ """
77
+ :param x: input mini-batch (bs, samp)
78
+ """
79
+
80
+ with torch.cuda.amp.autocast(enabled=False):
81
+ x = self.preprocess(x)
82
+ x = torch.abs(self.conv1(x))
83
+ if self.log_sinc:
84
+ x = torch.log(x + 1e-6)
85
+ if self.norm_sinc == "mean":
86
+ x = x - torch.mean(x, dim=-1, keepdim=True)
87
+ elif self.norm_sinc == "mean_std":
88
+ m = torch.mean(x, dim=-1, keepdim=True)
89
+ s = torch.std(x, dim=-1, keepdim=True)
90
+ s[s < 0.001] = 0.001
91
+ x = (x - m) / s
92
+
93
+ if self.summed:
94
+ x1 = self.layer1(x)
95
+ x2 = self.layer2(x1)
96
+ x3 = self.layer3(self.mp3(x1) + x2)
97
+ else:
98
+ x1 = self.layer1(x)
99
+ x2 = self.layer2(x1)
100
+ x3 = self.layer3(x2)
101
+
102
+ x = self.layer4(torch.cat((self.mp3(x1), x2, x3), dim=1))
103
+ x = self.relu(x)
104
+
105
+ t = x.size()[-1]
106
+
107
+ if self.context:
108
+ global_x = torch.cat(
109
+ (
110
+ x,
111
+ torch.mean(x, dim=2, keepdim=True).repeat(1, 1, t),
112
+ torch.sqrt(
113
+ torch.var(x, dim=2, keepdim=True).clamp(min=1e-4, max=1e4)
114
+ ).repeat(1, 1, t),
115
+ ),
116
+ dim=1,
117
+ )
118
+ else:
119
+ global_x = x
120
+
121
+ w = self.attention(global_x)
122
+
123
+ mu = torch.sum(x * w, dim=2)
124
+ sg = torch.sqrt(
125
+ (torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-4, max=1e4)
126
+ )
127
+
128
+ x = torch.cat((mu, sg), 1)
129
+
130
+ x = self.bn5(x)
131
+
132
+ x = self.fc6(x)
133
+
134
+ if self.out_bn:
135
+ x = self.bn6(x)
136
+
137
+ return x
138
+
139
+
140
+ def MainModel(**kwargs):
141
+ model = RawNet3(Bottle2neck, model_scale=8, context=True, summed=True, **kwargs)
142
+ return model
evaluation/metrics/similarity/models/__init__.py ADDED
File without changes
evaluation/metrics/similarity/speaker_similarity.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+
8
+ import numpy as np
9
+ import soundfile as sf
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from tqdm import tqdm
13
+ import librosa
14
+
15
+ from .models.RawNetModel import RawNet3
16
+ from .models.RawNetBasicBlock import Bottle2neck
17
+
18
+
19
+ def extract_speaker_embd(
20
+ model, fn: str, n_samples: int, n_segments: int = 10, gpu: bool = False
21
+ ) -> np.ndarray:
22
+ audio, sample_rate = sf.read(fn)
23
+ if len(audio.shape) > 1:
24
+ raise ValueError(
25
+ f"RawNet3 supports mono input only. Input data has a shape of {audio.shape}."
26
+ )
27
+
28
+ if sample_rate != 16000:
29
+ # resample to 16000kHz
30
+ audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=16000)
31
+ # print("resample to 16000kHz!")
32
+ if len(audio) < n_samples: # RawNet3 was trained using utterances of 3 seconds
33
+ shortage = n_samples - len(audio) + 1
34
+ audio = np.pad(audio, (0, shortage), "wrap")
35
+
36
+ audios = []
37
+ startframe = np.linspace(0, len(audio) - n_samples, num=n_segments)
38
+ for asf in startframe:
39
+ audios.append(audio[int(asf) : int(asf) + n_samples])
40
+
41
+ audios = torch.from_numpy(np.stack(audios, axis=0).astype(np.float32))
42
+ if gpu:
43
+ audios = audios.to("cuda")
44
+ with torch.no_grad():
45
+ output = model(audios)
46
+
47
+ return output
48
+
49
+
50
+ def extract_speaker_similarity(target_path, reference_path):
51
+ model = RawNet3(
52
+ Bottle2neck,
53
+ model_scale=8,
54
+ context=True,
55
+ summed=True,
56
+ encoder_type="ECA",
57
+ nOut=256,
58
+ out_bn=False,
59
+ sinc_stride=10,
60
+ log_sinc=True,
61
+ norm_sinc="mean",
62
+ grad_mult=1,
63
+ )
64
+
65
+ gpu = False
66
+ model.load_state_dict(
67
+ torch.load(
68
+ "pretrained/rawnet3/model.pt",
69
+ map_location=lambda storage, loc: storage,
70
+ )["model"]
71
+ )
72
+ model.eval()
73
+ print("RawNet3 initialised & weights loaded!")
74
+
75
+ if torch.cuda.is_available():
76
+ print("Cuda available, conducting inference on GPU")
77
+ model = model.to("cuda")
78
+ gpu = True
79
+ # for target_path, reference_path in zip(target_paths, ref_paths):
80
+ # print(f"Extracting embeddings for target singers...")
81
+
82
+ target_embeddings = []
83
+ for file in tqdm(os.listdir(target_path)):
84
+ output = extract_speaker_embd(
85
+ model,
86
+ fn=os.path.join(target_path, file),
87
+ n_samples=48000,
88
+ n_segments=10,
89
+ gpu=gpu,
90
+ ).mean(0)
91
+ target_embeddings.append(output.detach().cpu().numpy())
92
+ target_embeddings = np.array(target_embeddings)
93
+ target_embedding = np.mean(target_embeddings, axis=0)
94
+
95
+ # print(f"Extracting embeddings for reference singer...")
96
+
97
+ reference_embeddings = []
98
+ for file in tqdm(os.listdir(reference_path)):
99
+ output = extract_speaker_embd(
100
+ model,
101
+ fn=os.path.join(reference_path, file),
102
+ n_samples=48000,
103
+ n_segments=10,
104
+ gpu=gpu,
105
+ ).mean(0)
106
+ reference_embeddings.append(output.detach().cpu().numpy())
107
+ reference_embeddings = np.array(reference_embeddings)
108
+
109
+ # print("Calculating cosine similarity...")
110
+
111
+ cos_sim = F.cosine_similarity(
112
+ torch.from_numpy(np.mean(target_embeddings, axis=0)).unsqueeze(0),
113
+ torch.from_numpy(np.mean(reference_embeddings, axis=0)).unsqueeze(0),
114
+ dim=1,
115
+ )
116
+
117
+ # print(f"Mean cosine similarity: {cos_sim.item()}")
118
+
119
+ return cos_sim.item()
evaluation/metrics/spectrogram/__init__.py ADDED
File without changes
evaluation/metrics/spectrogram/frechet_distance.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from frechet_audio_distance import FrechetAudioDistance
7
+
8
+
9
+ def extract_fad(
10
+ audio_dir1,
11
+ audio_dir2,
12
+ mode="vggish",
13
+ use_pca=False,
14
+ use_activation=False,
15
+ verbose=False,
16
+ ):
17
+ """Extract Frechet Audio Distance for two given audio folders.
18
+ audio_dir1: path to the ground truth audio folder.
19
+ audio_dir2: path to the predicted audio folder.
20
+ mode: "vggish", "pann", "clap" for different models.
21
+ """
22
+ frechet = FrechetAudioDistance(
23
+ model_name=mode,
24
+ use_pca=use_pca,
25
+ use_activation=use_activation,
26
+ verbose=verbose,
27
+ )
28
+
29
+ fad_score = frechet.score(audio_dir1, audio_dir2)
30
+
31
+ return fad_score
evaluation/metrics/spectrogram/mel_cepstral_distortion.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from pymcd.mcd import Calculate_MCD
7
+
8
+
9
+ def extract_mcd(audio_ref, audio_deg, fs=None, mode="dtw_sl"):
10
+ """Extract Mel-Cepstral Distance for a two given audio.
11
+ Args:
12
+ audio_ref: The given reference audio. It is an audio path.
13
+ audio_deg: The given synthesized audio. It is an audio path.
14
+ mode: "plain", "dtw" and "dtw_sl".
15
+ """
16
+ mcd_toolbox = Calculate_MCD(MCD_mode=mode)
17
+ if fs != None:
18
+ mcd_toolbox.SAMPLING_RATE = fs
19
+ mcd_value = mcd_toolbox.calculate_mcd(audio_ref, audio_deg)
20
+
21
+ return mcd_value
evaluation/metrics/spectrogram/multi_resolution_stft_distance.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import librosa
7
+ import torch
8
+
9
+ import numpy as np
10
+
11
+
12
+ def extract_mstft(
13
+ audio_ref,
14
+ audio_deg,
15
+ fs=None,
16
+ mid_freq=None,
17
+ high_freq=None,
18
+ method="cut",
19
+ version="pwg",
20
+ ):
21
+ """Compute Multi-Scale STFT Distance (mstft) between the predicted and the ground truth audio.
22
+ audio_ref: path to the ground truth audio.
23
+ audio_deg: path to the predicted audio.
24
+ fs: sampling rate.
25
+ med_freq: division frequency for mid frequency parts.
26
+ high_freq: division frequency for high frequency parts.
27
+ method: "dtw" will use dtw algorithm to align the length of the ground truth and predicted audio.
28
+ "cut" will cut both audios into a same length according to the one with the shorter length.
29
+ version: "pwg" will use the computational version provided by ParallelWaveGAN.
30
+ "encodec" will use the computational version provided by Encodec.
31
+ """
32
+ # Load audio
33
+ if fs != None:
34
+ audio_ref, _ = librosa.load(audio_ref, sr=fs)
35
+ audio_deg, _ = librosa.load(audio_deg, sr=fs)
36
+ else:
37
+ audio_ref, fs = librosa.load(audio_ref)
38
+ audio_deg, fs = librosa.load(audio_deg)
39
+
40
+ # Automatically choose mid_freq and high_freq if they are not given
41
+ if mid_freq == None:
42
+ mid_freq = fs // 6
43
+ if high_freq == None:
44
+ high_freq = fs // 3
45
+
46
+ # Audio length alignment
47
+ if len(audio_ref) != len(audio_deg):
48
+ if method == "cut":
49
+ length = min(len(audio_ref), len(audio_deg))
50
+ audio_ref = audio_ref[:length]
51
+ audio_deg = audio_deg[:length]
52
+ elif method == "dtw":
53
+ _, wp = librosa.sequence.dtw(audio_ref, audio_deg, backtrack=True)
54
+ audio_ref_new = []
55
+ audio_deg_new = []
56
+ for i in range(wp.shape[0]):
57
+ ref_index = wp[i][0]
58
+ deg_index = wp[i][1]
59
+ audio_ref_new.append(audio_ref[ref_index])
60
+ audio_deg_new.append(audio_deg[deg_index])
61
+ audio_ref = np.array(audio_ref_new)
62
+ audio_deg = np.array(audio_deg_new)
63
+ assert len(audio_ref) == len(audio_deg)
64
+
65
+ # Define loss function
66
+ l1Loss = torch.nn.L1Loss(reduction="mean")
67
+ l2Loss = torch.nn.MSELoss(reduction="mean")
68
+
69
+ # Compute distance
70
+ if version == "encodec":
71
+ n_fft = 1024
72
+
73
+ mstft = 0
74
+ mstft_low = 0
75
+ mstft_mid = 0
76
+ mstft_high = 0
77
+
78
+ freq_resolution = fs / n_fft
79
+ mid_freq_index = 1 + int(np.floor(mid_freq / freq_resolution))
80
+ high_freq_index = 1 + int(np.floor(high_freq / freq_resolution))
81
+
82
+ for i in range(5, 11):
83
+ hop_length = 2**i // 4
84
+ win_length = 2**i
85
+
86
+ spec_ref = librosa.stft(
87
+ y=audio_ref, n_fft=n_fft, hop_length=hop_length, win_length=win_length
88
+ )
89
+ spec_deg = librosa.stft(
90
+ y=audio_deg, n_fft=n_fft, hop_length=hop_length, win_length=win_length
91
+ )
92
+
93
+ mag_ref = np.abs(spec_ref)
94
+ mag_deg = np.abs(spec_deg)
95
+
96
+ mag_ref = torch.from_numpy(mag_ref)
97
+ mag_deg = torch.from_numpy(mag_deg)
98
+ mstft += l1Loss(mag_ref, mag_deg) + l2Loss(mag_ref, mag_deg)
99
+
100
+ mag_ref_low = mag_ref[:mid_freq_index, :]
101
+ mag_deg_low = mag_deg[:mid_freq_index, :]
102
+ mstft_low += l1Loss(mag_ref_low, mag_deg_low) + l2Loss(
103
+ mag_ref_low, mag_deg_low
104
+ )
105
+
106
+ mag_ref_mid = mag_ref[mid_freq_index:high_freq_index, :]
107
+ mag_deg_mid = mag_deg[mid_freq_index:high_freq_index, :]
108
+ mstft_mid += l1Loss(mag_ref_mid, mag_deg_mid) + l2Loss(
109
+ mag_ref_mid, mag_deg_mid
110
+ )
111
+
112
+ mag_ref_high = mag_ref[high_freq_index:, :]
113
+ mag_deg_high = mag_deg[high_freq_index:, :]
114
+ mstft_high += l1Loss(mag_ref_high, mag_deg_high) + l2Loss(
115
+ mag_ref_high, mag_deg_high
116
+ )
117
+
118
+ mstft /= 6
119
+ mstft_low /= 6
120
+ mstft_mid /= 6
121
+ mstft_high /= 6
122
+
123
+ return mstft
124
+ elif version == "pwg":
125
+ fft_sizes = [1024, 2048, 512]
126
+ hop_sizes = [120, 240, 50]
127
+ win_sizes = [600, 1200, 240]
128
+
129
+ audio_ref = torch.from_numpy(audio_ref)
130
+ audio_deg = torch.from_numpy(audio_deg)
131
+
132
+ mstft_sc = 0
133
+ mstft_sc_low = 0
134
+ mstft_sc_mid = 0
135
+ mstft_sc_high = 0
136
+
137
+ mstft_mag = 0
138
+ mstft_mag_low = 0
139
+ mstft_mag_mid = 0
140
+ mstft_mag_high = 0
141
+
142
+ for n_fft, hop_length, win_length in zip(fft_sizes, hop_sizes, win_sizes):
143
+ spec_ref = torch.stft(
144
+ audio_ref, n_fft, hop_length, win_length, return_complex=False
145
+ )
146
+ spec_deg = torch.stft(
147
+ audio_deg, n_fft, hop_length, win_length, return_complex=False
148
+ )
149
+
150
+ real_ref = spec_ref[..., 0]
151
+ imag_ref = spec_ref[..., 1]
152
+ real_deg = spec_deg[..., 0]
153
+ imag_deg = spec_deg[..., 1]
154
+
155
+ mag_ref = torch.sqrt(
156
+ torch.clamp(real_ref**2 + imag_ref**2, min=1e-7)
157
+ ).transpose(1, 0)
158
+ mag_deg = torch.sqrt(
159
+ torch.clamp(real_deg**2 + imag_deg**2, min=1e-7)
160
+ ).transpose(1, 0)
161
+ sc_loss = torch.norm(mag_ref - mag_deg, p="fro") / torch.norm(
162
+ mag_ref, p="fro"
163
+ )
164
+ mag_loss = l1Loss(torch.log(mag_ref), torch.log(mag_deg))
165
+
166
+ mstft_sc += sc_loss
167
+ mstft_mag += mag_loss
168
+
169
+ freq_resolution = fs / n_fft
170
+ mid_freq_index = 1 + int(np.floor(mid_freq / freq_resolution))
171
+ high_freq_index = 1 + int(np.floor(high_freq / freq_resolution))
172
+
173
+ mag_ref_low = mag_ref[:, :mid_freq_index]
174
+ mag_deg_low = mag_deg[:, :mid_freq_index]
175
+ sc_loss_low = torch.norm(mag_ref_low - mag_deg_low, p="fro") / torch.norm(
176
+ mag_ref_low, p="fro"
177
+ )
178
+ mag_loss_low = l1Loss(torch.log(mag_ref_low), torch.log(mag_deg_low))
179
+
180
+ mstft_sc_low += sc_loss_low
181
+ mstft_mag_low += mag_loss_low
182
+
183
+ mag_ref_mid = mag_ref[:, mid_freq_index:high_freq_index]
184
+ mag_deg_mid = mag_deg[:, mid_freq_index:high_freq_index]
185
+ sc_loss_mid = torch.norm(mag_ref_mid - mag_deg_mid, p="fro") / torch.norm(
186
+ mag_ref_mid, p="fro"
187
+ )
188
+ mag_loss_mid = l1Loss(torch.log(mag_ref_mid), torch.log(mag_deg_mid))
189
+
190
+ mstft_sc_mid += sc_loss_mid
191
+ mstft_mag_mid += mag_loss_mid
192
+
193
+ mag_ref_high = mag_ref[:, high_freq_index:]
194
+ mag_deg_high = mag_deg[:, high_freq_index:]
195
+ sc_loss_high = torch.norm(
196
+ mag_ref_high - mag_deg_high, p="fro"
197
+ ) / torch.norm(mag_ref_high, p="fro")
198
+ mag_loss_high = l1Loss(torch.log(mag_ref_high), torch.log(mag_deg_high))
199
+
200
+ mstft_sc_high += sc_loss_high
201
+ mstft_mag_high += mag_loss_high
202
+
203
+ # Normalize distances
204
+ mstft_sc /= len(fft_sizes)
205
+ mstft_sc_low /= len(fft_sizes)
206
+ mstft_sc_mid /= len(fft_sizes)
207
+ mstft_sc_high /= len(fft_sizes)
208
+
209
+ mstft_mag /= len(fft_sizes)
210
+ mstft_mag_low /= len(fft_sizes)
211
+ mstft_mag_mid /= len(fft_sizes)
212
+ mstft_mag_high /= len(fft_sizes)
213
+
214
+ # return (
215
+ # mstft_sc.numpy().tolist(),
216
+ # mstft_sc_low.numpy().tolist(),
217
+ # mstft_sc_mid.numpy().tolist(),
218
+ # mstft_sc_high.numpy().tolist(),
219
+ # mstft_mag.numpy().tolist(),
220
+ # mstft_mag_low.numpy().tolist(),
221
+ # mstft_mag_mid.numpy().tolist(),
222
+ # mstft_mag_high.numpy().tolist(),
223
+ # )
224
+
225
+ return mstft_sc.numpy().tolist() + mstft_mag.numpy().tolist()
evaluation/metrics/spectrogram/pesq.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import librosa
7
+
8
+ import numpy as np
9
+
10
+ from pypesq import pesq
11
+
12
+
13
+ def extract_pesq(audio_ref, audio_deg, fs=None, method="cut"):
14
+ """Extract PESQ for a two given audio.
15
+ audio1: the given reference audio. It is a numpy array.
16
+ audio2: the given synthesized audio. It is a numpy array.
17
+ fs: sampling rate.
18
+ method: "dtw" will use dtw algorithm to align the length of the ground truth and predicted audio.
19
+ "cut" will cut both audios into a same length according to the one with the shorter length.
20
+ """
21
+ # Load audio
22
+ if fs != None:
23
+ audio_ref, _ = librosa.load(audio_ref, sr=fs)
24
+ audio_deg, _ = librosa.load(audio_deg, sr=fs)
25
+ else:
26
+ audio_ref, fs = librosa.load(audio_ref)
27
+ audio_deg, fs = librosa.load(audio_deg)
28
+
29
+ # Resample
30
+ if fs != 16000:
31
+ audio_ref = librosa.resample(audio_ref, orig_sr=fs, target_sr=16000)
32
+ audio_deg = librosa.resample(audio_deg, orig_sr=fs, target_sr=16000)
33
+ fs = 16000
34
+
35
+ # Audio length alignment
36
+ if len(audio_ref) != len(audio_deg):
37
+ if method == "cut":
38
+ length = min(len(audio_ref), len(audio_deg))
39
+ audio_ref = audio_ref[:length]
40
+ audio_deg = audio_deg[:length]
41
+ elif method == "dtw":
42
+ _, wp = librosa.sequence.dtw(audio_ref, audio_deg, backtrack=True)
43
+ audio_ref_new = []
44
+ audio_deg_new = []
45
+ for i in range(wp.shape[0]):
46
+ ref_index = wp[i][0]
47
+ deg_index = wp[i][1]
48
+ audio_ref_new.append(audio_ref[ref_index])
49
+ audio_deg_new.append(audio_deg[deg_index])
50
+ audio_ref = np.array(audio_ref_new)
51
+ audio_deg = np.array(audio_deg_new)
52
+ assert len(audio_ref) == len(audio_deg)
53
+
54
+ # Compute pesq
55
+ score = pesq(audio_ref, audio_deg, fs)
56
+ return score
evaluation/metrics/spectrogram/scale_invariant_signal_to_distortion_ratio.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import librosa
8
+
9
+ import numpy as np
10
+
11
+ from torchmetrics import ScaleInvariantSignalDistortionRatio
12
+
13
+
14
+ def extract_si_sdr(audio_ref, audio_deg, fs=None, method="cut"):
15
+ si_sdr = ScaleInvariantSignalDistortionRatio()
16
+
17
+ if fs != None:
18
+ audio_ref, _ = librosa.load(audio_ref, sr=fs)
19
+ audio_deg, _ = librosa.load(audio_deg, sr=fs)
20
+ else:
21
+ audio_ref, fs = librosa.load(audio_ref)
22
+ audio_deg, fs = librosa.load(audio_deg)
23
+
24
+ if len(audio_ref) != len(audio_deg):
25
+ if method == "cut":
26
+ length = min(len(audio_ref), len(audio_deg))
27
+ audio_ref = audio_ref[:length]
28
+ audio_deg = audio_deg[:length]
29
+ elif method == "dtw":
30
+ _, wp = librosa.sequence.dtw(audio_ref, audio_deg, backtrack=True)
31
+ audio_ref_new = []
32
+ audio_deg_new = []
33
+ for i in range(wp.shape[0]):
34
+ ref_index = wp[i][0]
35
+ deg_index = wp[i][1]
36
+ audio_ref_new.append(audio_ref[ref_index])
37
+ audio_deg_new.append(audio_deg[deg_index])
38
+ audio_ref = np.array(audio_ref_new)
39
+ audio_deg = np.array(audio_deg_new)
40
+ assert len(audio_ref) == len(audio_deg)
41
+
42
+ audio_ref = torch.from_numpy(audio_ref)
43
+ audio_deg = torch.from_numpy(audio_deg)
44
+
45
+ return si_sdr(audio_deg, audio_ref)
evaluation/metrics/spectrogram/scale_invariant_signal_to_noise_ratio.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import librosa
8
+
9
+ import numpy as np
10
+
11
+ from torchmetrics import ScaleInvariantSignalNoiseRatio
12
+
13
+
14
+ def extract_si_snr(audio_ref, audio_deg, fs=None, method="cut"):
15
+ si_snr = ScaleInvariantSignalNoiseRatio()
16
+
17
+ if fs != None:
18
+ audio_ref, _ = librosa.load(audio_ref, sr=fs)
19
+ audio_deg, _ = librosa.load(audio_deg, sr=fs)
20
+ else:
21
+ audio_ref, fs = librosa.load(audio_ref)
22
+ audio_deg, fs = librosa.load(audio_deg)
23
+
24
+ if len(audio_ref) != len(audio_deg):
25
+ if method == "cut":
26
+ length = min(len(audio_ref), len(audio_deg))
27
+ audio_ref = audio_ref[:length]
28
+ audio_deg = audio_deg[:length]
29
+ elif method == "dtw":
30
+ _, wp = librosa.sequence.dtw(audio_ref, audio_deg, backtrack=True)
31
+ audio_ref_new = []
32
+ audio_deg_new = []
33
+ for i in range(wp.shape[0]):
34
+ ref_index = wp[i][0]
35
+ deg_index = wp[i][1]
36
+ audio_ref_new.append(audio_ref[ref_index])
37
+ audio_deg_new.append(audio_deg[deg_index])
38
+ audio_ref = np.array(audio_ref_new)
39
+ audio_deg = np.array(audio_deg_new)
40
+ assert len(audio_ref) == len(audio_deg)
41
+
42
+ audio_ref = torch.from_numpy(audio_ref)
43
+ audio_deg = torch.from_numpy(audio_deg)
44
+
45
+ return si_snr(audio_deg, audio_ref)
evaluation/metrics/spectrogram/short_time_objective_intelligibility.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import librosa
8
+
9
+ import numpy as np
10
+
11
+ from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
12
+
13
+
14
+ def extract_stoi(audio_ref, audio_deg, fs=None, extended=False, method="cut"):
15
+ """Compute Short-Time Objective Intelligibility between the predicted and the ground truth audio.
16
+ audio_ref: path to the ground truth audio.
17
+ audio_deg: path to the predicted audio.
18
+ fs: sampling rate.
19
+ method: "dtw" will use dtw algorithm to align the length of the ground truth and predicted audio.
20
+ "cut" will cut both audios into a same length according to the one with the shorter length.
21
+ """
22
+ # Load audio
23
+ if fs != None:
24
+ audio_ref, _ = librosa.load(audio_ref, sr=fs)
25
+ audio_deg, _ = librosa.load(audio_deg, sr=fs)
26
+ else:
27
+ audio_ref, fs = librosa.load(audio_ref)
28
+ audio_deg, fs = librosa.load(audio_deg)
29
+
30
+ # Initialize method
31
+ stoi = ShortTimeObjectiveIntelligibility(fs, extended)
32
+
33
+ # Audio length alignment
34
+ if len(audio_ref) != len(audio_deg):
35
+ if method == "cut":
36
+ length = min(len(audio_ref), len(audio_deg))
37
+ audio_ref = audio_ref[:length]
38
+ audio_deg = audio_deg[:length]
39
+ elif method == "dtw":
40
+ _, wp = librosa.sequence.dtw(audio_ref, audio_deg, backtrack=True)
41
+ audio_ref_new = []
42
+ audio_deg_new = []
43
+ for i in range(wp.shape[0]):
44
+ ref_index = wp[i][0]
45
+ deg_index = wp[i][1]
46
+ audio_ref_new.append(audio_ref[ref_index])
47
+ audio_deg_new.append(audio_deg[deg_index])
48
+ audio_ref = np.array(audio_ref_new)
49
+ audio_deg = np.array(audio_deg_new)
50
+ assert len(audio_ref) == len(audio_deg)
51
+
52
+ # Convert to tensor
53
+ audio_ref = torch.from_numpy(audio_ref)
54
+ audio_deg = torch.from_numpy(audio_deg)
55
+
56
+ return stoi(audio_deg, audio_ref).numpy().tolist()
models/tts/base/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ # from .tts_inferece import TTSInference
7
+ from .tts_trainer import TTSTrainer
models/tts/base/tts_dataset.py ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import json
7
+ import os
8
+ import torchaudio
9
+ import numpy as np
10
+ import torch
11
+ from utils.data_utils import *
12
+ from torch.nn.utils.rnn import pad_sequence
13
+ from text import text_to_sequence
14
+ from text.text_token_collation import phoneIDCollation
15
+ from processors.acoustic_extractor import cal_normalized_mel
16
+
17
+ from models.base.base_dataset import (
18
+ BaseDataset,
19
+ BaseCollator,
20
+ BaseTestDataset,
21
+ BaseTestCollator,
22
+ )
23
+
24
+ from processors.content_extractor import (
25
+ ContentvecExtractor,
26
+ WenetExtractor,
27
+ WhisperExtractor,
28
+ )
29
+
30
+
31
+ class TTSDataset(BaseDataset):
32
+ def __init__(self, cfg, dataset, is_valid=False):
33
+ """
34
+ Args:
35
+ cfg: config
36
+ dataset: dataset name
37
+ is_valid: whether to use train or valid dataset
38
+ """
39
+
40
+ assert isinstance(dataset, str)
41
+
42
+ self.cfg = cfg
43
+
44
+ processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
45
+ meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
46
+ self.metafile_path = os.path.join(processed_data_dir, meta_file)
47
+ self.metadata = self.get_metadata()
48
+
49
+ """
50
+ load spk2id and utt2spk from json file
51
+ spk2id: {spk1: 0, spk2: 1, ...}
52
+ utt2spk: {dataset_uid: spk1, ...}
53
+ """
54
+ if cfg.preprocess.use_spkid:
55
+ dataset = self.metadata[0]["Dataset"]
56
+
57
+ spk2id_path = os.path.join(processed_data_dir, cfg.preprocess.spk2id)
58
+ with open(spk2id_path, "r") as f:
59
+ self.spk2id = json.load(f)
60
+
61
+ utt2spk_path = os.path.join(processed_data_dir, cfg.preprocess.utt2spk)
62
+ self.utt2spk = dict()
63
+ with open(utt2spk_path, "r") as f:
64
+ for line in f.readlines():
65
+ utt, spk = line.strip().split("\t")
66
+ self.utt2spk[utt] = spk
67
+
68
+ if cfg.preprocess.use_uv:
69
+ self.utt2uv_path = {}
70
+ for utt_info in self.metadata:
71
+ dataset = utt_info["Dataset"]
72
+ uid = utt_info["Uid"]
73
+ utt = "{}_{}".format(dataset, uid)
74
+ self.utt2uv_path[utt] = os.path.join(
75
+ cfg.preprocess.processed_dir,
76
+ dataset,
77
+ cfg.preprocess.uv_dir,
78
+ uid + ".npy",
79
+ )
80
+
81
+ if cfg.preprocess.use_frame_pitch:
82
+ self.utt2frame_pitch_path = {}
83
+ for utt_info in self.metadata:
84
+ dataset = utt_info["Dataset"]
85
+ uid = utt_info["Uid"]
86
+ utt = "{}_{}".format(dataset, uid)
87
+
88
+ self.utt2frame_pitch_path[utt] = os.path.join(
89
+ cfg.preprocess.processed_dir,
90
+ dataset,
91
+ cfg.preprocess.pitch_dir,
92
+ uid + ".npy",
93
+ )
94
+
95
+ if cfg.preprocess.use_frame_energy:
96
+ self.utt2frame_energy_path = {}
97
+ for utt_info in self.metadata:
98
+ dataset = utt_info["Dataset"]
99
+ uid = utt_info["Uid"]
100
+ utt = "{}_{}".format(dataset, uid)
101
+
102
+ self.utt2frame_energy_path[utt] = os.path.join(
103
+ cfg.preprocess.processed_dir,
104
+ dataset,
105
+ cfg.preprocess.energy_dir,
106
+ uid + ".npy",
107
+ )
108
+
109
+ if cfg.preprocess.use_mel:
110
+ self.utt2mel_path = {}
111
+ for utt_info in self.metadata:
112
+ dataset = utt_info["Dataset"]
113
+ uid = utt_info["Uid"]
114
+ utt = "{}_{}".format(dataset, uid)
115
+
116
+ self.utt2mel_path[utt] = os.path.join(
117
+ cfg.preprocess.processed_dir,
118
+ dataset,
119
+ cfg.preprocess.mel_dir,
120
+ uid + ".npy",
121
+ )
122
+
123
+ if cfg.preprocess.use_linear:
124
+ self.utt2linear_path = {}
125
+ for utt_info in self.metadata:
126
+ dataset = utt_info["Dataset"]
127
+ uid = utt_info["Uid"]
128
+ utt = "{}_{}".format(dataset, uid)
129
+
130
+ self.utt2linear_path[utt] = os.path.join(
131
+ cfg.preprocess.processed_dir,
132
+ dataset,
133
+ cfg.preprocess.linear_dir,
134
+ uid + ".npy",
135
+ )
136
+
137
+ if cfg.preprocess.use_audio:
138
+ self.utt2audio_path = {}
139
+ for utt_info in self.metadata:
140
+ dataset = utt_info["Dataset"]
141
+ uid = utt_info["Uid"]
142
+ utt = "{}_{}".format(dataset, uid)
143
+
144
+ if cfg.preprocess.extract_audio:
145
+ self.utt2audio_path[utt] = os.path.join(
146
+ cfg.preprocess.processed_dir,
147
+ dataset,
148
+ cfg.preprocess.audio_dir,
149
+ uid + ".wav",
150
+ )
151
+ else:
152
+ self.utt2audio_path[utt] = utt_info["Path"]
153
+
154
+ # self.utt2audio_path[utt] = os.path.join(
155
+ # cfg.preprocess.processed_dir,
156
+ # dataset,
157
+ # cfg.preprocess.audio_dir,
158
+ # uid + ".numpy",
159
+ # )
160
+
161
+ elif cfg.preprocess.use_label:
162
+ self.utt2label_path = {}
163
+ for utt_info in self.metadata:
164
+ dataset = utt_info["Dataset"]
165
+ uid = utt_info["Uid"]
166
+ utt = "{}_{}".format(dataset, uid)
167
+
168
+ self.utt2label_path[utt] = os.path.join(
169
+ cfg.preprocess.processed_dir,
170
+ dataset,
171
+ cfg.preprocess.label_dir,
172
+ uid + ".npy",
173
+ )
174
+ elif cfg.preprocess.use_one_hot:
175
+ self.utt2one_hot_path = {}
176
+ for utt_info in self.metadata:
177
+ dataset = utt_info["Dataset"]
178
+ uid = utt_info["Uid"]
179
+ utt = "{}_{}".format(dataset, uid)
180
+
181
+ self.utt2one_hot_path[utt] = os.path.join(
182
+ cfg.preprocess.processed_dir,
183
+ dataset,
184
+ cfg.preprocess.one_hot_dir,
185
+ uid + ".npy",
186
+ )
187
+
188
+ if cfg.preprocess.use_text or cfg.preprocess.use_phone:
189
+ self.utt2seq = {}
190
+ for utt_info in self.metadata:
191
+ dataset = utt_info["Dataset"]
192
+ uid = utt_info["Uid"]
193
+ utt = "{}_{}".format(dataset, uid)
194
+
195
+ if cfg.preprocess.use_text:
196
+ text = utt_info["Text"]
197
+ sequence = text_to_sequence(text, cfg.preprocess.text_cleaners)
198
+ elif cfg.preprocess.use_phone:
199
+ # load phoneme squence from phone file
200
+ phone_path = os.path.join(
201
+ processed_data_dir, cfg.preprocess.phone_dir, uid + ".phone"
202
+ )
203
+ with open(phone_path, "r") as fin:
204
+ phones = fin.readlines()
205
+ assert len(phones) == 1
206
+ phones = phones[0].strip()
207
+ phones_seq = phones.split(" ")
208
+
209
+ phon_id_collator = phoneIDCollation(cfg, dataset=dataset)
210
+ sequence = phon_id_collator.get_phone_id_sequence(cfg, phones_seq)
211
+
212
+ self.utt2seq[utt] = sequence
213
+
214
+ def __getitem__(self, index):
215
+ utt_info = self.metadata[index]
216
+
217
+ dataset = utt_info["Dataset"]
218
+ uid = utt_info["Uid"]
219
+ utt = "{}_{}".format(dataset, uid)
220
+
221
+ single_feature = dict()
222
+
223
+ if self.cfg.preprocess.use_spkid:
224
+ single_feature["spk_id"] = np.array(
225
+ [self.spk2id[self.utt2spk[utt]]], dtype=np.int32
226
+ )
227
+
228
+ if self.cfg.preprocess.use_mel:
229
+ mel = np.load(self.utt2mel_path[utt])
230
+ assert mel.shape[0] == self.cfg.preprocess.n_mel # [n_mels, T]
231
+ if self.cfg.preprocess.use_min_max_norm_mel:
232
+ # do mel norm
233
+ mel = cal_normalized_mel(mel, utt_info["Dataset"], self.cfg.preprocess)
234
+
235
+ if "target_len" not in single_feature.keys():
236
+ single_feature["target_len"] = mel.shape[1]
237
+ single_feature["mel"] = mel.T # [T, n_mels]
238
+
239
+ if self.cfg.preprocess.use_linear:
240
+ linear = np.load(self.utt2linear_path[utt])
241
+ if "target_len" not in single_feature.keys():
242
+ single_feature["target_len"] = linear.shape[1]
243
+ single_feature["linear"] = linear.T # [T, n_linear]
244
+
245
+ if self.cfg.preprocess.use_frame_pitch:
246
+ frame_pitch_path = self.utt2frame_pitch_path[utt]
247
+ frame_pitch = np.load(frame_pitch_path)
248
+ if "target_len" not in single_feature.keys():
249
+ single_feature["target_len"] = len(frame_pitch)
250
+ aligned_frame_pitch = align_length(
251
+ frame_pitch, single_feature["target_len"]
252
+ )
253
+ single_feature["frame_pitch"] = aligned_frame_pitch
254
+
255
+ if self.cfg.preprocess.use_uv:
256
+ frame_uv_path = self.utt2uv_path[utt]
257
+ frame_uv = np.load(frame_uv_path)
258
+ aligned_frame_uv = align_length(frame_uv, single_feature["target_len"])
259
+ aligned_frame_uv = [
260
+ 0 if frame_uv else 1 for frame_uv in aligned_frame_uv
261
+ ]
262
+ aligned_frame_uv = np.array(aligned_frame_uv)
263
+ single_feature["frame_uv"] = aligned_frame_uv
264
+
265
+ if self.cfg.preprocess.use_frame_energy:
266
+ frame_energy_path = self.utt2frame_energy_path[utt]
267
+ frame_energy = np.load(frame_energy_path)
268
+ if "target_len" not in single_feature.keys():
269
+ single_feature["target_len"] = len(frame_energy)
270
+ aligned_frame_energy = align_length(
271
+ frame_energy, single_feature["target_len"]
272
+ )
273
+ single_feature["frame_energy"] = aligned_frame_energy
274
+
275
+ if self.cfg.preprocess.use_audio:
276
+ audio, sr = torchaudio.load(self.utt2audio_path[utt])
277
+ audio = audio.cpu().numpy().squeeze()
278
+ single_feature["audio"] = audio
279
+ single_feature["audio_len"] = audio.shape[0]
280
+
281
+ if self.cfg.preprocess.use_phone or self.cfg.preprocess.use_text:
282
+ single_feature["phone_seq"] = np.array(self.utt2seq[utt])
283
+ single_feature["phone_len"] = len(self.utt2seq[utt])
284
+
285
+ return single_feature
286
+
287
+ def __len__(self):
288
+ return super().__len__()
289
+
290
+ def get_metadata(self):
291
+ return super().get_metadata()
292
+
293
+
294
+ class TTSCollator(BaseCollator):
295
+ """Zero-pads model inputs and targets based on number of frames per step"""
296
+
297
+ def __init__(self, cfg):
298
+ super().__init__(cfg)
299
+
300
+ def __call__(self, batch):
301
+ parsed_batch_features = super().__call__(batch)
302
+ return parsed_batch_features
303
+
304
+
305
+ class TTSTestDataset(BaseTestDataset):
306
+ def __init__(self, args, cfg):
307
+ self.cfg = cfg
308
+
309
+ # inference from test list file
310
+ if args.test_list_file is not None:
311
+ # construst metadata
312
+ self.metadata = []
313
+
314
+ with open(args.test_list_file, "r") as fin:
315
+ for idx, line in enumerate(fin.readlines()):
316
+ utt_info = {}
317
+
318
+ utt_info["Dataset"] = "test"
319
+ utt_info["Text"] = line.strip()
320
+ utt_info["Uid"] = str(idx)
321
+ self.metadata.append(utt_info)
322
+
323
+ else:
324
+ assert args.testing_set
325
+ self.metafile_path = os.path.join(
326
+ cfg.preprocess.processed_dir,
327
+ args.dataset,
328
+ "{}.json".format(args.testing_set),
329
+ )
330
+ self.metadata = self.get_metadata()
331
+
332
+ def __getitem__(self, index):
333
+ single_feature = {}
334
+
335
+ return single_feature
336
+
337
+ def __len__(self):
338
+ return len(self.metadata)
339
+
340
+
341
+ class TTSTestCollator(BaseTestCollator):
342
+ """Zero-pads model inputs and targets based on number of frames per step"""
343
+
344
+ def __init__(self, cfg):
345
+ self.cfg = cfg
346
+
347
+ def __call__(self, batch):
348
+ packed_batch_features = dict()
349
+
350
+ # mel: [b, T, n_mels]
351
+ # frame_pitch, frame_energy: [1, T]
352
+ # target_len: [1]
353
+ # spk_id: [b, 1]
354
+ # mask: [b, T, 1]
355
+
356
+ for key in batch[0].keys():
357
+ if key == "target_len":
358
+ packed_batch_features["target_len"] = torch.LongTensor(
359
+ [b["target_len"] for b in batch]
360
+ )
361
+ masks = [
362
+ torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
363
+ ]
364
+ packed_batch_features["mask"] = pad_sequence(
365
+ masks, batch_first=True, padding_value=0
366
+ )
367
+ elif key == "phone_len":
368
+ packed_batch_features["phone_len"] = torch.LongTensor(
369
+ [b["phone_len"] for b in batch]
370
+ )
371
+ masks = [
372
+ torch.ones((b["phone_len"], 1), dtype=torch.long) for b in batch
373
+ ]
374
+ packed_batch_features["phn_mask"] = pad_sequence(
375
+ masks, batch_first=True, padding_value=0
376
+ )
377
+ elif key == "audio_len":
378
+ packed_batch_features["audio_len"] = torch.LongTensor(
379
+ [b["audio_len"] for b in batch]
380
+ )
381
+ masks = [
382
+ torch.ones((b["audio_len"], 1), dtype=torch.long) for b in batch
383
+ ]
384
+ else:
385
+ values = [torch.from_numpy(b[key]) for b in batch]
386
+ packed_batch_features[key] = pad_sequence(
387
+ values, batch_first=True, padding_value=0
388
+ )
389
+ return packed_batch_features
models/tts/base/tts_inferece.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import torch
8
+ import time
9
+ import accelerate
10
+ import random
11
+ import numpy as np
12
+ from tqdm import tqdm
13
+ from accelerate.logging import get_logger
14
+ from torch.utils.data import DataLoader
15
+
16
+
17
+ from abc import abstractmethod
18
+ from pathlib import Path
19
+ from utils.io import save_audio
20
+ from utils.util import load_config
21
+ from models.vocoders.vocoder_inference import synthesis
22
+
23
+
24
+ class TTSInference(object):
25
+ def __init__(self, args=None, cfg=None):
26
+ super().__init__()
27
+
28
+ start = time.monotonic_ns()
29
+ self.args = args
30
+ self.cfg = cfg
31
+ self.infer_type = args.mode
32
+
33
+ # get exp_dir
34
+ if self.args.acoustics_dir is not None:
35
+ self.exp_dir = self.args.acoustics_dir
36
+ elif self.args.checkpoint_path is not None:
37
+ self.exp_dir = os.path.dirname(os.path.dirname(self.args.checkpoint_path))
38
+
39
+ # Init accelerator
40
+ self.accelerator = accelerate.Accelerator()
41
+ self.accelerator.wait_for_everyone()
42
+ self.device = self.accelerator.device
43
+
44
+ # Get logger
45
+ with self.accelerator.main_process_first():
46
+ self.logger = get_logger("inference", log_level=args.log_level)
47
+
48
+ # Log some info
49
+ self.logger.info("=" * 56)
50
+ self.logger.info("||\t\t" + "New inference process started." + "\t\t||")
51
+ self.logger.info("=" * 56)
52
+ self.logger.info("\n")
53
+
54
+ self.acoustic_model_dir = args.acoustics_dir
55
+ self.logger.debug(f"Acoustic model dir: {args.acoustics_dir}")
56
+
57
+ if args.vocoder_dir is not None:
58
+ self.vocoder_dir = args.vocoder_dir
59
+ self.logger.debug(f"Vocoder dir: {args.vocoder_dir}")
60
+
61
+ os.makedirs(args.output_dir, exist_ok=True)
62
+
63
+ # Set random seed
64
+ with self.accelerator.main_process_first():
65
+ start = time.monotonic_ns()
66
+ self._set_random_seed(self.cfg.train.random_seed)
67
+ end = time.monotonic_ns()
68
+ self.logger.debug(
69
+ f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
70
+ )
71
+ self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
72
+
73
+ # Setup data loader
74
+ if self.infer_type == "batch":
75
+ with self.accelerator.main_process_first():
76
+ self.logger.info("Building dataset...")
77
+ start = time.monotonic_ns()
78
+ self.test_dataloader = self._build_test_dataloader()
79
+ end = time.monotonic_ns()
80
+ self.logger.info(
81
+ f"Building dataset done in {(end - start) / 1e6:.2f}ms"
82
+ )
83
+
84
+ # Build model
85
+ with self.accelerator.main_process_first():
86
+ self.logger.info("Building model...")
87
+ start = time.monotonic_ns()
88
+ self.model = self._build_model()
89
+ end = time.monotonic_ns()
90
+ self.logger.info(f"Building model done in {(end - start) / 1e6:.3f}ms")
91
+
92
+ # Init with accelerate
93
+ self.logger.info("Initializing accelerate...")
94
+ start = time.monotonic_ns()
95
+ self.accelerator = accelerate.Accelerator()
96
+ self.model = self.accelerator.prepare(self.model)
97
+ if self.infer_type == "batch":
98
+ self.test_dataloader = self.accelerator.prepare(self.test_dataloader)
99
+ end = time.monotonic_ns()
100
+ self.accelerator.wait_for_everyone()
101
+ self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.3f}ms")
102
+
103
+ with self.accelerator.main_process_first():
104
+ self.logger.info("Loading checkpoint...")
105
+ start = time.monotonic_ns()
106
+ if args.acoustics_dir is not None:
107
+ self._load_model(
108
+ checkpoint_dir=os.path.join(args.acoustics_dir, "checkpoint")
109
+ )
110
+ elif args.checkpoint_path is not None:
111
+ self._load_model(checkpoint_path=args.checkpoint_path)
112
+ else:
113
+ print("Either checkpoint dir or checkpoint path should be provided.")
114
+
115
+ end = time.monotonic_ns()
116
+ self.logger.info(f"Loading checkpoint done in {(end - start) / 1e6:.3f}ms")
117
+
118
+ self.model.eval()
119
+ self.accelerator.wait_for_everyone()
120
+
121
+ def _build_test_dataset(self):
122
+ pass
123
+
124
+ def _build_model(self):
125
+ pass
126
+
127
+ # TODO: LEGACY CODE
128
+ def _build_test_dataloader(self):
129
+ datasets, collate = self._build_test_dataset()
130
+ self.test_dataset = datasets(self.args, self.cfg)
131
+ self.test_collate = collate(self.cfg)
132
+ self.test_batch_size = min(
133
+ self.cfg.train.batch_size, len(self.test_dataset.metadata)
134
+ )
135
+ test_dataloader = DataLoader(
136
+ self.test_dataset,
137
+ collate_fn=self.test_collate,
138
+ num_workers=1,
139
+ batch_size=self.test_batch_size,
140
+ shuffle=False,
141
+ )
142
+ return test_dataloader
143
+
144
+ def _load_model(
145
+ self,
146
+ checkpoint_dir: str = None,
147
+ checkpoint_path: str = None,
148
+ old_mode: bool = False,
149
+ ):
150
+ r"""Load model from checkpoint. If checkpoint_path is None, it will
151
+ load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
152
+ None, it will load the checkpoint specified by checkpoint_path. **Only use this
153
+ method after** ``accelerator.prepare()``.
154
+ """
155
+
156
+ if checkpoint_path is None:
157
+ assert checkpoint_dir is not None
158
+ # Load the latest accelerator state dicts
159
+ ls = [
160
+ str(i) for i in Path(checkpoint_dir).glob("*") if not "audio" in str(i)
161
+ ]
162
+ ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
163
+ checkpoint_path = ls[0]
164
+
165
+ self.accelerator.load_state(str(checkpoint_path))
166
+ return str(checkpoint_path)
167
+
168
+ def inference(self):
169
+ if self.infer_type == "single":
170
+ out_dir = os.path.join(self.args.output_dir, "single")
171
+ os.makedirs(out_dir, exist_ok=True)
172
+
173
+ pred_audio = self.inference_for_single_utterance()
174
+ save_path = os.path.join(out_dir, "test_pred.wav")
175
+ save_audio(save_path, pred_audio, self.cfg.preprocess.sample_rate)
176
+
177
+ elif self.infer_type == "batch":
178
+ out_dir = os.path.join(self.args.output_dir, "batch")
179
+ os.makedirs(out_dir, exist_ok=True)
180
+
181
+ pred_audio_list = self.inference_for_batches()
182
+ for it, wav in zip(self.test_dataset.metadata, pred_audio_list):
183
+ uid = it["Uid"]
184
+ save_audio(
185
+ os.path.join(out_dir, f"{uid}.wav"),
186
+ wav.numpy(),
187
+ self.cfg.preprocess.sample_rate,
188
+ add_silence=True,
189
+ turn_up=True,
190
+ )
191
+ tmp_file = os.path.join(out_dir, f"{uid}.pt")
192
+ if os.path.exists(tmp_file):
193
+ os.remove(tmp_file)
194
+ print("Saved to: ", out_dir)
195
+
196
+ @torch.inference_mode()
197
+ def inference_for_batches(self):
198
+ y_pred = []
199
+ for i, batch in tqdm(enumerate(self.test_dataloader)):
200
+ y_pred, mel_lens, _ = self._inference_each_batch(batch)
201
+ y_ls = y_pred.chunk(self.test_batch_size)
202
+ tgt_ls = mel_lens.chunk(self.test_batch_size)
203
+ j = 0
204
+ for it, l in zip(y_ls, tgt_ls):
205
+ l = l.item()
206
+ it = it.squeeze(0)[:l].detach().cpu()
207
+
208
+ uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"]
209
+ torch.save(it, os.path.join(self.args.output_dir, f"{uid}.pt"))
210
+ j += 1
211
+
212
+ vocoder_cfg, vocoder_ckpt = self._parse_vocoder(self.args.vocoder_dir)
213
+ res = synthesis(
214
+ cfg=vocoder_cfg,
215
+ vocoder_weight_file=vocoder_ckpt,
216
+ n_samples=None,
217
+ pred=[
218
+ torch.load(
219
+ os.path.join(self.args.output_dir, "{}.pt".format(item["Uid"]))
220
+ ).numpy()
221
+ for item in self.test_dataset.metadata
222
+ ],
223
+ )
224
+ for it, wav in zip(self.test_dataset.metadata, res):
225
+ uid = it["Uid"]
226
+ save_audio(
227
+ os.path.join(self.args.output_dir, f"{uid}.wav"),
228
+ wav.numpy(),
229
+ 22050,
230
+ add_silence=True,
231
+ turn_up=True,
232
+ )
233
+
234
+ @abstractmethod
235
+ @torch.inference_mode()
236
+ def _inference_each_batch(self, batch_data):
237
+ pass
238
+
239
+ def inference_for_single_utterance(self, text):
240
+ pass
241
+
242
+ def synthesis_by_vocoder(self, pred):
243
+ audios_pred = synthesis(
244
+ self.vocoder_cfg,
245
+ self.checkpoint_dir_vocoder,
246
+ len(pred),
247
+ pred,
248
+ )
249
+
250
+ return audios_pred
251
+
252
+ @staticmethod
253
+ def _parse_vocoder(vocoder_dir):
254
+ r"""Parse vocoder config"""
255
+ vocoder_dir = os.path.abspath(vocoder_dir)
256
+ ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")]
257
+ ckpt_list.sort(key=lambda x: int(x.stem), reverse=True)
258
+ ckpt_path = str(ckpt_list[0])
259
+ vocoder_cfg = load_config(
260
+ os.path.join(vocoder_dir, "args.json"), lowercase=True
261
+ )
262
+ return vocoder_cfg, ckpt_path
263
+
264
+ def _set_random_seed(self, seed):
265
+ """Set random seed for all possible random modules."""
266
+ random.seed(seed)
267
+ np.random.seed(seed)
268
+ torch.random.manual_seed(seed)
models/tts/base/tts_trainer.py ADDED
@@ -0,0 +1,699 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import json
7
+ import os
8
+ import shutil
9
+ import torch
10
+ import time
11
+ from pathlib import Path
12
+ import torch
13
+ from tqdm import tqdm
14
+ import re
15
+ import logging
16
+ import json5
17
+ import accelerate
18
+ from accelerate.logging import get_logger
19
+ from accelerate.utils import ProjectConfiguration
20
+ from torch.utils.data import ConcatDataset, DataLoader
21
+ from accelerate import DistributedDataParallelKwargs
22
+ from schedulers.scheduler import Eden
23
+ from models.base.base_sampler import build_samplers
24
+ from models.base.new_trainer import BaseTrainer
25
+
26
+
27
+ class TTSTrainer(BaseTrainer):
28
+ r"""The base trainer for all TTS models. It inherits from BaseTrainer and implements
29
+ ``build_criterion``, ``_build_dataset`` and ``_build_singer_lut`` methods. You can inherit from this
30
+ class, and implement ``_build_model``, ``_forward_step``.
31
+ """
32
+
33
+ def __init__(self, args=None, cfg=None):
34
+ self.args = args
35
+ self.cfg = cfg
36
+
37
+ cfg.exp_name = args.exp_name
38
+
39
+ # init with accelerate
40
+ self._init_accelerator()
41
+ self.accelerator.wait_for_everyone()
42
+
43
+ with self.accelerator.main_process_first():
44
+ self.logger = get_logger(args.exp_name, log_level="INFO")
45
+
46
+ # Log some info
47
+ self.logger.info("=" * 56)
48
+ self.logger.info("||\t\t" + "New training process started." + "\t\t||")
49
+ self.logger.info("=" * 56)
50
+ self.logger.info("\n")
51
+ self.logger.debug(f"Using {args.log_level.upper()} logging level.")
52
+ self.logger.info(f"Experiment name: {args.exp_name}")
53
+ self.logger.info(f"Experiment directory: {self.exp_dir}")
54
+ self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
55
+ if self.accelerator.is_main_process:
56
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
57
+ self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
58
+
59
+ # init counts
60
+ self.batch_count: int = 0
61
+ self.step: int = 0
62
+ self.epoch: int = 0
63
+ self.max_epoch = (
64
+ self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
65
+ )
66
+ self.logger.info(
67
+ "Max epoch: {}".format(
68
+ self.max_epoch if self.max_epoch < float("inf") else "Unlimited"
69
+ )
70
+ )
71
+
72
+ # Check values
73
+ if self.accelerator.is_main_process:
74
+ self.__check_basic_configs()
75
+ # Set runtime configs
76
+ self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
77
+ self.checkpoints_path = [
78
+ [] for _ in range(len(self.save_checkpoint_stride))
79
+ ]
80
+ self.keep_last = [
81
+ i if i > 0 else float("inf") for i in self.cfg.train.keep_last
82
+ ]
83
+ self.run_eval = self.cfg.train.run_eval
84
+
85
+ # set random seed
86
+ with self.accelerator.main_process_first():
87
+ start = time.monotonic_ns()
88
+ self._set_random_seed(self.cfg.train.random_seed)
89
+ end = time.monotonic_ns()
90
+ self.logger.debug(
91
+ f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
92
+ )
93
+ self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
94
+
95
+ # setup data_loader
96
+ with self.accelerator.main_process_first():
97
+ self.logger.info("Building dataset...")
98
+ start = time.monotonic_ns()
99
+ self.train_dataloader, self.valid_dataloader = self._build_dataloader()
100
+ end = time.monotonic_ns()
101
+ self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
102
+
103
+ # save phone table to exp dir. Should be done before building model due to loading phone table in model
104
+ if cfg.preprocess.use_phone and cfg.preprocess.phone_extractor != "lexicon":
105
+ self._save_phone_symbols_file_to_exp_path()
106
+
107
+ # setup model
108
+ with self.accelerator.main_process_first():
109
+ self.logger.info("Building model...")
110
+ start = time.monotonic_ns()
111
+ self.model = self._build_model()
112
+ end = time.monotonic_ns()
113
+ self.logger.debug(self.model)
114
+ self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms")
115
+ self.logger.info(
116
+ f"Model parameters: {self.__count_parameters(self.model)/1e6:.2f}M"
117
+ )
118
+
119
+ # optimizer & scheduler
120
+ with self.accelerator.main_process_first():
121
+ self.logger.info("Building optimizer and scheduler...")
122
+ start = time.monotonic_ns()
123
+ self.optimizer = self._build_optimizer()
124
+ self.scheduler = self._build_scheduler()
125
+ end = time.monotonic_ns()
126
+ self.logger.info(
127
+ f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms"
128
+ )
129
+
130
+ # create criterion
131
+ with self.accelerator.main_process_first():
132
+ self.logger.info("Building criterion...")
133
+ start = time.monotonic_ns()
134
+ self.criterion = self._build_criterion()
135
+ end = time.monotonic_ns()
136
+ self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms")
137
+
138
+ # Resume or Finetune
139
+ with self.accelerator.main_process_first():
140
+ self._check_resume()
141
+
142
+ # accelerate prepare
143
+ self.logger.info("Initializing accelerate...")
144
+ start = time.monotonic_ns()
145
+ self._accelerator_prepare()
146
+ end = time.monotonic_ns()
147
+ self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms")
148
+
149
+ # save config file path
150
+ self.config_save_path = os.path.join(self.exp_dir, "args.json")
151
+ self.device = self.accelerator.device
152
+
153
+ if cfg.preprocess.use_spkid and cfg.train.multi_speaker_training:
154
+ self.speakers = self._build_speaker_lut()
155
+ self.utt2spk_dict = self._build_utt2spk_dict()
156
+
157
+ # Only for TTS tasks
158
+ self.task_type = "TTS"
159
+ self.logger.info("Task type: {}".format(self.task_type))
160
+
161
+ def _check_resume(self):
162
+ # if args.resume:
163
+ if self.args.resume or (
164
+ self.cfg.model_type == "VALLE" and self.args.train_stage == 2
165
+ ):
166
+ if self.cfg.model_type == "VALLE" and self.args.train_stage == 2:
167
+ self.args.resume_type = "finetune"
168
+
169
+ self.logger.info("Resuming from checkpoint...")
170
+ start = time.monotonic_ns()
171
+ self.ckpt_path = self._load_model(
172
+ self.checkpoint_dir, self.args.checkpoint_path, self.args.resume_type
173
+ )
174
+ end = time.monotonic_ns()
175
+ self.logger.info(
176
+ f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
177
+ )
178
+ self.checkpoints_path = json.load(
179
+ open(os.path.join(self.ckpt_path, "ckpts.json"), "r")
180
+ )
181
+
182
+ self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
183
+ if self.accelerator.is_main_process:
184
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
185
+ self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
186
+
187
+ def _init_accelerator(self):
188
+ self.exp_dir = os.path.join(
189
+ os.path.abspath(self.cfg.log_dir), self.args.exp_name
190
+ )
191
+ project_config = ProjectConfiguration(
192
+ project_dir=self.exp_dir,
193
+ logging_dir=os.path.join(self.exp_dir, "log"),
194
+ )
195
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
196
+ self.accelerator = accelerate.Accelerator(
197
+ gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step,
198
+ log_with=self.cfg.train.tracker,
199
+ project_config=project_config,
200
+ kwargs_handlers=[kwargs],
201
+ )
202
+ if self.accelerator.is_main_process:
203
+ os.makedirs(project_config.project_dir, exist_ok=True)
204
+ os.makedirs(project_config.logging_dir, exist_ok=True)
205
+ with self.accelerator.main_process_first():
206
+ self.accelerator.init_trackers(self.args.exp_name)
207
+
208
+ def _accelerator_prepare(self):
209
+ (
210
+ self.train_dataloader,
211
+ self.valid_dataloader,
212
+ ) = self.accelerator.prepare(
213
+ self.train_dataloader,
214
+ self.valid_dataloader,
215
+ )
216
+
217
+ if isinstance(self.model, dict):
218
+ for key in self.model.keys():
219
+ self.model[key] = self.accelerator.prepare(self.model[key])
220
+ else:
221
+ self.model = self.accelerator.prepare(self.model)
222
+
223
+ if isinstance(self.optimizer, dict):
224
+ for key in self.optimizer.keys():
225
+ self.optimizer[key] = self.accelerator.prepare(self.optimizer[key])
226
+ else:
227
+ self.optimizer = self.accelerator.prepare(self.optimizer)
228
+
229
+ if isinstance(self.scheduler, dict):
230
+ for key in self.scheduler.keys():
231
+ self.scheduler[key] = self.accelerator.prepare(self.scheduler[key])
232
+ else:
233
+ self.scheduler = self.accelerator.prepare(self.scheduler)
234
+
235
+ ### Following are methods only for TTS tasks ###
236
+ def _build_dataset(self):
237
+ pass
238
+
239
+ def _build_criterion(self):
240
+ pass
241
+
242
+ def _build_model(self):
243
+ pass
244
+
245
+ def _build_dataloader(self):
246
+ """Build dataloader which merges a series of datasets."""
247
+ # Build dataset instance for each dataset and combine them by ConcatDataset
248
+ Dataset, Collator = self._build_dataset()
249
+
250
+ # Build train set
251
+ datasets_list = []
252
+ for dataset in self.cfg.dataset:
253
+ subdataset = Dataset(self.cfg, dataset, is_valid=False)
254
+ datasets_list.append(subdataset)
255
+ train_dataset = ConcatDataset(datasets_list)
256
+ train_collate = Collator(self.cfg)
257
+ _, batch_sampler = build_samplers(train_dataset, self.cfg, self.logger, "train")
258
+ train_loader = DataLoader(
259
+ train_dataset,
260
+ collate_fn=train_collate,
261
+ batch_sampler=batch_sampler,
262
+ num_workers=self.cfg.train.dataloader.num_worker,
263
+ pin_memory=self.cfg.train.dataloader.pin_memory,
264
+ )
265
+
266
+ # Build test set
267
+ datasets_list = []
268
+ for dataset in self.cfg.dataset:
269
+ subdataset = Dataset(self.cfg, dataset, is_valid=True)
270
+ datasets_list.append(subdataset)
271
+ valid_dataset = ConcatDataset(datasets_list)
272
+ valid_collate = Collator(self.cfg)
273
+ _, batch_sampler = build_samplers(valid_dataset, self.cfg, self.logger, "valid")
274
+ valid_loader = DataLoader(
275
+ valid_dataset,
276
+ collate_fn=valid_collate,
277
+ batch_sampler=batch_sampler,
278
+ num_workers=self.cfg.train.dataloader.num_worker,
279
+ pin_memory=self.cfg.train.dataloader.pin_memory,
280
+ )
281
+ return train_loader, valid_loader
282
+
283
+ def _build_optimizer(self):
284
+ pass
285
+
286
+ def _build_scheduler(self):
287
+ pass
288
+
289
+ def _load_model(self, checkpoint_dir, checkpoint_path=None, resume_type="resume"):
290
+ """Load model from checkpoint. If a folder is given, it will
291
+ load the latest checkpoint in checkpoint_dir. If a path is given
292
+ it will load the checkpoint specified by checkpoint_path.
293
+ **Only use this method after** ``accelerator.prepare()``.
294
+ """
295
+ if checkpoint_path is None:
296
+ ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
297
+ ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
298
+ checkpoint_path = ls[0]
299
+ self.logger.info("Load model from {}".format(checkpoint_path))
300
+ print("Load model from {}".format(checkpoint_path))
301
+ if resume_type == "resume":
302
+ self.accelerator.load_state(checkpoint_path)
303
+ self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
304
+ self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1
305
+ elif resume_type == "finetune":
306
+ self.model.load_state_dict(
307
+ torch.load(os.path.join(checkpoint_path, "pytorch_model.bin"))
308
+ )
309
+ self.model.cuda(self.accelerator.device)
310
+ self.logger.info("Load model weights for finetune SUCCESS!")
311
+ else:
312
+ raise ValueError("Unsupported resume type: {}".format(resume_type))
313
+
314
+ return checkpoint_path
315
+
316
+ ### THIS IS MAIN ENTRY ###
317
+ def train_loop(self):
318
+ r"""Training loop. The public entry of training process."""
319
+ # Wait everyone to prepare before we move on
320
+ self.accelerator.wait_for_everyone()
321
+ # dump config file
322
+ if self.accelerator.is_main_process:
323
+ self.__dump_cfg(self.config_save_path)
324
+
325
+ # self.optimizer.zero_grad()
326
+ # Wait to ensure good to go
327
+
328
+ self.accelerator.wait_for_everyone()
329
+ while self.epoch < self.max_epoch:
330
+ self.logger.info("\n")
331
+ self.logger.info("-" * 32)
332
+ self.logger.info("Epoch {}: ".format(self.epoch))
333
+
334
+ # Do training & validating epoch
335
+ train_total_loss, train_losses = self._train_epoch()
336
+ if isinstance(train_losses, dict):
337
+ for key, loss in train_losses.items():
338
+ self.logger.info(" |- Train/{} Loss: {:.6f}".format(key, loss))
339
+ self.accelerator.log(
340
+ {"Epoch/Train {} Loss".format(key): loss},
341
+ step=self.epoch,
342
+ )
343
+
344
+ valid_total_loss, valid_losses = self._valid_epoch()
345
+ if isinstance(valid_losses, dict):
346
+ for key, loss in valid_losses.items():
347
+ self.logger.info(" |- Valid/{} Loss: {:.6f}".format(key, loss))
348
+ self.accelerator.log(
349
+ {"Epoch/Train {} Loss".format(key): loss},
350
+ step=self.epoch,
351
+ )
352
+
353
+ self.logger.info(" |- Train/Loss: {:.6f}".format(train_total_loss))
354
+ self.logger.info(" |- Valid/Loss: {:.6f}".format(valid_total_loss))
355
+ self.accelerator.log(
356
+ {
357
+ "Epoch/Train Loss": train_total_loss,
358
+ "Epoch/Valid Loss": valid_total_loss,
359
+ },
360
+ step=self.epoch,
361
+ )
362
+
363
+ self.accelerator.wait_for_everyone()
364
+
365
+ # Check if hit save_checkpoint_stride and run_eval
366
+ run_eval = False
367
+ if self.accelerator.is_main_process:
368
+ save_checkpoint = False
369
+ hit_dix = []
370
+ for i, num in enumerate(self.save_checkpoint_stride):
371
+ if self.epoch % num == 0:
372
+ save_checkpoint = True
373
+ hit_dix.append(i)
374
+ run_eval |= self.run_eval[i]
375
+
376
+ self.accelerator.wait_for_everyone()
377
+ if self.accelerator.is_main_process and save_checkpoint:
378
+ path = os.path.join(
379
+ self.checkpoint_dir,
380
+ "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
381
+ self.epoch, self.step, train_total_loss
382
+ ),
383
+ )
384
+ self.accelerator.save_state(path)
385
+
386
+ json.dump(
387
+ self.checkpoints_path,
388
+ open(os.path.join(path, "ckpts.json"), "w"),
389
+ ensure_ascii=False,
390
+ indent=4,
391
+ )
392
+
393
+ # Remove old checkpoints
394
+ to_remove = []
395
+ for idx in hit_dix:
396
+ self.checkpoints_path[idx].append(path)
397
+ while len(self.checkpoints_path[idx]) > self.keep_last[idx]:
398
+ to_remove.append((idx, self.checkpoints_path[idx].pop(0)))
399
+
400
+ # Search conflicts
401
+ total = set()
402
+ for i in self.checkpoints_path:
403
+ total |= set(i)
404
+ do_remove = set()
405
+ for idx, path in to_remove[::-1]:
406
+ if path in total:
407
+ self.checkpoints_path[idx].insert(0, path)
408
+ else:
409
+ do_remove.add(path)
410
+
411
+ # Remove old checkpoints
412
+ for path in do_remove:
413
+ shutil.rmtree(path, ignore_errors=True)
414
+ self.logger.debug(f"Remove old checkpoint: {path}")
415
+
416
+ self.accelerator.wait_for_everyone()
417
+ if run_eval:
418
+ # TODO: run evaluation
419
+ pass
420
+
421
+ # Update info for each epoch
422
+ self.epoch += 1
423
+
424
+ # Finish training and save final checkpoint
425
+ self.accelerator.wait_for_everyone()
426
+ if self.accelerator.is_main_process:
427
+ path = os.path.join(
428
+ self.checkpoint_dir,
429
+ "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
430
+ self.epoch, self.step, valid_total_loss
431
+ ),
432
+ )
433
+ self.accelerator.save_state(
434
+ os.path.join(
435
+ self.checkpoint_dir,
436
+ "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
437
+ self.epoch, self.step, valid_total_loss
438
+ ),
439
+ )
440
+ )
441
+
442
+ json.dump(
443
+ self.checkpoints_path,
444
+ open(os.path.join(path, "ckpts.json"), "w"),
445
+ ensure_ascii=False,
446
+ indent=4,
447
+ )
448
+
449
+ self.accelerator.end_training()
450
+
451
+ ### Following are methods that can be used directly in child classes ###
452
+ def _train_epoch(self):
453
+ r"""Training epoch. Should return average loss of a batch (sample) over
454
+ one epoch. See ``train_loop`` for usage.
455
+ """
456
+ if isinstance(self.model, dict):
457
+ for key in self.model.keys():
458
+ self.model[key].train()
459
+ else:
460
+ self.model.train()
461
+
462
+ epoch_sum_loss: float = 0.0
463
+ epoch_losses: dict = {}
464
+ epoch_step: int = 0
465
+ for batch in tqdm(
466
+ self.train_dataloader,
467
+ desc=f"Training Epoch {self.epoch}",
468
+ unit="batch",
469
+ colour="GREEN",
470
+ leave=False,
471
+ dynamic_ncols=True,
472
+ smoothing=0.04,
473
+ disable=not self.accelerator.is_main_process,
474
+ ):
475
+ # Do training step and BP
476
+ with self.accelerator.accumulate(self.model):
477
+ total_loss, train_losses, _ = self._train_step(batch)
478
+ self.batch_count += 1
479
+
480
+ # Update info for each step
481
+ # TODO: step means BP counts or batch counts?
482
+ if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
483
+ if isinstance(self.scheduler, dict):
484
+ for key in self.scheduler.keys():
485
+ self.scheduler[key].step()
486
+ else:
487
+ if isinstance(self.scheduler, Eden):
488
+ self.scheduler.step_batch(self.step)
489
+ else:
490
+ self.scheduler.step()
491
+
492
+ epoch_sum_loss += total_loss
493
+
494
+ if isinstance(train_losses, dict):
495
+ for key, value in train_losses.items():
496
+ epoch_losses[key] += value
497
+
498
+ if isinstance(train_losses, dict):
499
+ for key, loss in train_losses.items():
500
+ self.accelerator.log(
501
+ {"Epoch/Train {} Loss".format(key): loss},
502
+ step=self.step,
503
+ )
504
+
505
+ self.step += 1
506
+ epoch_step += 1
507
+
508
+ self.accelerator.wait_for_everyone()
509
+
510
+ epoch_sum_loss = (
511
+ epoch_sum_loss
512
+ / len(self.train_dataloader)
513
+ * self.cfg.train.gradient_accumulation_step
514
+ )
515
+
516
+ for key in epoch_losses.keys():
517
+ epoch_losses[key] = (
518
+ epoch_losses[key]
519
+ / len(self.train_dataloader)
520
+ * self.cfg.train.gradient_accumulation_step
521
+ )
522
+
523
+ return epoch_sum_loss, epoch_losses
524
+
525
+ @torch.inference_mode()
526
+ def _valid_epoch(self):
527
+ r"""Testing epoch. Should return average loss of a batch (sample) over
528
+ one epoch. See ``train_loop`` for usage.
529
+ """
530
+ if isinstance(self.model, dict):
531
+ for key in self.model.keys():
532
+ self.model[key].eval()
533
+ else:
534
+ self.model.eval()
535
+
536
+ epoch_sum_loss = 0.0
537
+ epoch_losses = dict()
538
+ for batch in tqdm(
539
+ self.valid_dataloader,
540
+ desc=f"Validating Epoch {self.epoch}",
541
+ unit="batch",
542
+ colour="GREEN",
543
+ leave=False,
544
+ dynamic_ncols=True,
545
+ smoothing=0.04,
546
+ disable=not self.accelerator.is_main_process,
547
+ ):
548
+ total_loss, valid_losses, valid_stats = self._valid_step(batch)
549
+ epoch_sum_loss += total_loss
550
+ if isinstance(valid_losses, dict):
551
+ for key, value in valid_losses.items():
552
+ if key not in epoch_losses.keys():
553
+ epoch_losses[key] = value
554
+ else:
555
+ epoch_losses[key] += value
556
+
557
+ epoch_sum_loss = epoch_sum_loss / len(self.valid_dataloader)
558
+ for key in epoch_losses.keys():
559
+ epoch_losses[key] = epoch_losses[key] / len(self.valid_dataloader)
560
+
561
+ self.accelerator.wait_for_everyone()
562
+
563
+ return epoch_sum_loss, epoch_losses
564
+
565
+ def _train_step(self):
566
+ pass
567
+
568
+ def _valid_step(self, batch):
569
+ pass
570
+
571
+ def _inference(self):
572
+ pass
573
+
574
+ def _is_valid_pattern(self, directory_name):
575
+ directory_name = str(directory_name)
576
+ pattern = r"^epoch-\d{4}_step-\d{7}_loss-\d{1}\.\d{6}"
577
+ return re.match(pattern, directory_name) is not None
578
+
579
+ def _check_basic_configs(self):
580
+ if self.cfg.train.gradient_accumulation_step <= 0:
581
+ self.logger.fatal("Invalid gradient_accumulation_step value!")
582
+ self.logger.error(
583
+ f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
584
+ )
585
+ self.accelerator.end_training()
586
+ raise ValueError(
587
+ f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
588
+ )
589
+
590
+ def __dump_cfg(self, path):
591
+ os.makedirs(os.path.dirname(path), exist_ok=True)
592
+ json5.dump(
593
+ self.cfg,
594
+ open(path, "w"),
595
+ indent=4,
596
+ sort_keys=True,
597
+ ensure_ascii=False,
598
+ quote_keys=True,
599
+ )
600
+
601
+ def __check_basic_configs(self):
602
+ if self.cfg.train.gradient_accumulation_step <= 0:
603
+ self.logger.fatal("Invalid gradient_accumulation_step value!")
604
+ self.logger.error(
605
+ f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
606
+ )
607
+ self.accelerator.end_training()
608
+ raise ValueError(
609
+ f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
610
+ )
611
+ # TODO: check other values
612
+
613
+ @staticmethod
614
+ def __count_parameters(model):
615
+ model_param = 0.0
616
+ if isinstance(model, dict):
617
+ for key, value in model.items():
618
+ model_param += sum(p.numel() for p in model[key].parameters())
619
+ else:
620
+ model_param = sum(p.numel() for p in model.parameters())
621
+ return model_param
622
+
623
+ def _build_speaker_lut(self):
624
+ # combine speakers
625
+ if not os.path.exists(os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)):
626
+ speakers = {}
627
+ else:
628
+ with open(
629
+ os.path.join(self.exp_dir, self.cfg.preprocess.spk2id), "r"
630
+ ) as speaker_file:
631
+ speakers = json.load(speaker_file)
632
+ for dataset in self.cfg.dataset:
633
+ speaker_lut_path = os.path.join(
634
+ self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.spk2id
635
+ )
636
+ with open(speaker_lut_path, "r") as speaker_lut_path:
637
+ singer_lut = json.load(speaker_lut_path)
638
+ for singer in singer_lut.keys():
639
+ if singer not in speakers:
640
+ speakers[singer] = len(speakers)
641
+ with open(
642
+ os.path.join(self.exp_dir, self.cfg.preprocess.spk2id), "w"
643
+ ) as speaker_file:
644
+ json.dump(speakers, speaker_file, indent=4, ensure_ascii=False)
645
+ print(
646
+ "speakers have been dumped to {}".format(
647
+ os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
648
+ )
649
+ )
650
+ return speakers
651
+
652
+ def _build_utt2spk_dict(self):
653
+ # combine speakers
654
+ utt2spk = {}
655
+ if not os.path.exists(os.path.join(self.exp_dir, self.cfg.preprocess.utt2spk)):
656
+ utt2spk = {}
657
+ else:
658
+ with open(
659
+ os.path.join(self.exp_dir, self.cfg.preprocess.utt2spk), "r"
660
+ ) as utt2spk_file:
661
+ for line in utt2spk_file.readlines():
662
+ utt, spk = line.strip().split("\t")
663
+ utt2spk[utt] = spk
664
+ for dataset in self.cfg.dataset:
665
+ utt2spk_dict_path = os.path.join(
666
+ self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.utt2spk
667
+ )
668
+ with open(utt2spk_dict_path, "r") as utt2spk_dict:
669
+ for line in utt2spk_dict.readlines():
670
+ utt, spk = line.strip().split("\t")
671
+ if utt not in utt2spk.keys():
672
+ utt2spk[utt] = spk
673
+ with open(
674
+ os.path.join(self.exp_dir, self.cfg.preprocess.utt2spk), "w"
675
+ ) as utt2spk_file:
676
+ for utt, spk in utt2spk.items():
677
+ utt2spk_file.write(utt + "\t" + spk + "\n")
678
+ print(
679
+ "utterance and speaker mapper have been dumped to {}".format(
680
+ os.path.join(self.exp_dir, self.cfg.preprocess.utt2spk)
681
+ )
682
+ )
683
+ return utt2spk
684
+
685
+ def _save_phone_symbols_file_to_exp_path(self):
686
+ phone_symbols_file = os.path.join(
687
+ self.cfg.preprocess.processed_dir,
688
+ self.cfg.dataset[0],
689
+ self.cfg.preprocess.symbols_dict,
690
+ )
691
+ phone_symbols_file_to_exp_path = os.path.join(
692
+ self.exp_dir, self.cfg.preprocess.symbols_dict
693
+ )
694
+ shutil.copy(phone_symbols_file, phone_symbols_file_to_exp_path)
695
+ print(
696
+ "phone symbols been dumped to {}".format(
697
+ os.path.join(self.exp_dir, self.cfg.preprocess.symbols_dict)
698
+ )
699
+ )
models/tts/fastspeech2/__init__.py ADDED
File without changes