SayaSS commited on
Commit
693a136
·
1 Parent(s): cf0491a

remove unnecessary files

Browse files
.pre-commit-config.yaml DELETED
@@ -1,25 +0,0 @@
1
- repos:
2
- - repo: https://github.com/pre-commit/pre-commit-hooks
3
- rev: v4.4.0
4
- hooks:
5
- - id: check-yaml
6
- - id: end-of-file-fixer
7
- - id: trailing-whitespace
8
-
9
- - repo: https://github.com/astral-sh/ruff-pre-commit
10
- rev: v0.0.292
11
- hooks:
12
- - id: ruff
13
- args: [ --fix ]
14
-
15
- - repo: https://github.com/psf/black
16
- rev: 23.9.1
17
- hooks:
18
- - id: black
19
-
20
- - repo: https://github.com/codespell-project/codespell
21
- rev: v2.2.6
22
- hooks:
23
- - id: codespell
24
- files: ^.*\.(py|md|rst|yml)$
25
- args: [-L=fro]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1.4.3 DELETED
Binary file (330 Bytes)
 
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Bert Vits2
3
  emoji: 📊
4
  colorFrom: red
5
  colorTo: green
 
1
  ---
2
+ title: Bert Vits2 JP
3
  emoji: 📊
4
  colorFrom: red
5
  colorTo: green
bert/chinese-roberta-wwm-ext-large/.gitattributes DELETED
@@ -1,9 +0,0 @@
1
- *.bin.* filter=lfs diff=lfs merge=lfs -text
2
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.h5 filter=lfs diff=lfs merge=lfs -text
5
- *.tflite filter=lfs diff=lfs merge=lfs -text
6
- *.tar.gz filter=lfs diff=lfs merge=lfs -text
7
- *.ot filter=lfs diff=lfs merge=lfs -text
8
- *.onnx filter=lfs diff=lfs merge=lfs -text
9
- *.msgpack filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
bert/chinese-roberta-wwm-ext-large/.gitignore DELETED
@@ -1 +0,0 @@
1
- *.bin
 
 
bert/chinese-roberta-wwm-ext-large/README.md DELETED
@@ -1,57 +0,0 @@
1
- ---
2
- language:
3
- - zh
4
- tags:
5
- - bert
6
- license: "apache-2.0"
7
- ---
8
-
9
- # Please use 'Bert' related functions to load this model!
10
-
11
- ## Chinese BERT with Whole Word Masking
12
- For further accelerating Chinese natural language processing, we provide **Chinese pre-trained BERT with Whole Word Masking**.
13
-
14
- **[Pre-Training with Whole Word Masking for Chinese BERT](https://arxiv.org/abs/1906.08101)**
15
- Yiming Cui, Wanxiang Che, Ting Liu, Bing Qin, Ziqing Yang, Shijin Wang, Guoping Hu
16
-
17
- This repository is developed based on:https://github.com/google-research/bert
18
-
19
- You may also interested in,
20
- - Chinese BERT series: https://github.com/ymcui/Chinese-BERT-wwm
21
- - Chinese MacBERT: https://github.com/ymcui/MacBERT
22
- - Chinese ELECTRA: https://github.com/ymcui/Chinese-ELECTRA
23
- - Chinese XLNet: https://github.com/ymcui/Chinese-XLNet
24
- - Knowledge Distillation Toolkit - TextBrewer: https://github.com/airaria/TextBrewer
25
-
26
- More resources by HFL: https://github.com/ymcui/HFL-Anthology
27
-
28
- ## Citation
29
- If you find the technical report or resource is useful, please cite the following technical report in your paper.
30
- - Primary: https://arxiv.org/abs/2004.13922
31
- ```
32
- @inproceedings{cui-etal-2020-revisiting,
33
- title = "Revisiting Pre-Trained Models for {C}hinese Natural Language Processing",
34
- author = "Cui, Yiming and
35
- Che, Wanxiang and
36
- Liu, Ting and
37
- Qin, Bing and
38
- Wang, Shijin and
39
- Hu, Guoping",
40
- booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing: Findings",
41
- month = nov,
42
- year = "2020",
43
- address = "Online",
44
- publisher = "Association for Computational Linguistics",
45
- url = "https://www.aclweb.org/anthology/2020.findings-emnlp.58",
46
- pages = "657--668",
47
- }
48
- ```
49
- - Secondary: https://arxiv.org/abs/1906.08101
50
- ```
51
- @article{chinese-bert-wwm,
52
- title={Pre-Training with Whole Word Masking for Chinese BERT},
53
- author={Cui, Yiming and Che, Wanxiang and Liu, Ting and Qin, Bing and Yang, Ziqing and Wang, Shijin and Hu, Guoping},
54
- journal={arXiv preprint arXiv:1906.08101},
55
- year={2019}
56
- }
57
- ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bert/chinese-roberta-wwm-ext-large/added_tokens.json DELETED
@@ -1 +0,0 @@
1
- {}
 
 
bert/chinese-roberta-wwm-ext-large/config.json DELETED
@@ -1,28 +0,0 @@
1
- {
2
- "architectures": [
3
- "BertForMaskedLM"
4
- ],
5
- "attention_probs_dropout_prob": 0.1,
6
- "bos_token_id": 0,
7
- "directionality": "bidi",
8
- "eos_token_id": 2,
9
- "hidden_act": "gelu",
10
- "hidden_dropout_prob": 0.1,
11
- "hidden_size": 1024,
12
- "initializer_range": 0.02,
13
- "intermediate_size": 4096,
14
- "layer_norm_eps": 1e-12,
15
- "max_position_embeddings": 512,
16
- "model_type": "bert",
17
- "num_attention_heads": 16,
18
- "num_hidden_layers": 24,
19
- "output_past": true,
20
- "pad_token_id": 0,
21
- "pooler_fc_size": 768,
22
- "pooler_num_attention_heads": 12,
23
- "pooler_num_fc_layers": 3,
24
- "pooler_size_per_head": 128,
25
- "pooler_type": "first_token_transform",
26
- "type_vocab_size": 2,
27
- "vocab_size": 21128
28
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bert/chinese-roberta-wwm-ext-large/special_tokens_map.json DELETED
@@ -1 +0,0 @@
1
- {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
 
 
bert/chinese-roberta-wwm-ext-large/tokenizer.json DELETED
The diff for this file is too large to render. See raw diff
 
bert/chinese-roberta-wwm-ext-large/tokenizer_config.json DELETED
@@ -1 +0,0 @@
1
- {"init_inputs": []}
 
 
bert/chinese-roberta-wwm-ext-large/vocab.txt DELETED
The diff for this file is too large to render. See raw diff
 
bert_gen.py DELETED
@@ -1,61 +0,0 @@
1
- import torch
2
- from multiprocessing import Pool
3
- import commons
4
- import utils
5
- from tqdm import tqdm
6
- from text import cleaned_text_to_sequence, get_bert
7
- import argparse
8
- import torch.multiprocessing as mp
9
-
10
- import os
11
- os.environ['http_proxy'] = 'http://localhost:11796'
12
- os.environ['https_proxy'] = 'http://localhost:11796'
13
- def process_line(line):
14
- rank = mp.current_process()._identity
15
- rank = rank[0] if len(rank) > 0 else 0
16
- if torch.cuda.is_available():
17
- gpu_id = rank % torch.cuda.device_count()
18
- device = torch.device(f"cuda:{gpu_id}")
19
- wav_path, _, language_str, text, phones, tone, word2ph = line.strip().split("|")
20
- phone = phones.split(" ")
21
- tone = [int(i) for i in tone.split(" ")]
22
- word2ph = [int(i) for i in word2ph.split(" ")]
23
- word2ph = [i for i in word2ph]
24
- phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
25
-
26
- phone = commons.intersperse(phone, 0)
27
- tone = commons.intersperse(tone, 0)
28
- language = commons.intersperse(language, 0)
29
- for i in range(len(word2ph)):
30
- word2ph[i] = word2ph[i] * 2
31
- word2ph[0] += 1
32
-
33
- bert_path = wav_path.replace(".wav", ".bert.pt")
34
-
35
- try:
36
- bert = torch.load(bert_path)
37
- assert bert.shape[-1] == len(phone)
38
- except Exception:
39
- bert = get_bert(text, word2ph, language_str, device)
40
- assert bert.shape[-1] == len(phone)
41
- torch.save(bert, bert_path)
42
-
43
-
44
- if __name__ == "__main__":
45
- parser = argparse.ArgumentParser()
46
- parser.add_argument("-c", "--config", type=str, default="configs/config.json")
47
- parser.add_argument("--num_processes", type=int, default=2)
48
- args = parser.parse_args()
49
- config_path = args.config
50
- hps = utils.get_hparams_from_file(config_path)
51
- lines = []
52
- with open(hps.data.training_files, encoding="utf-8") as f:
53
- lines.extend(f.readlines())
54
-
55
- with open(hps.data.validation_files, encoding="utf-8") as f:
56
- lines.extend(f.readlines())
57
-
58
- num_processes = args.num_processes
59
- with Pool(processes=num_processes) as pool:
60
- for _ in tqdm(pool.imap_unordered(process_line, lines), total=len(lines)):
61
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs/config.json DELETED
@@ -1,197 +0,0 @@
1
- {
2
- "train": {
3
- "log_interval": 20,
4
- "eval_interval": 500,
5
- "seed": 52,
6
- "epochs": 10000,
7
- "learning_rate": 0.0001,
8
- "betas": [
9
- 0.8,
10
- 0.99
11
- ],
12
- "eps": 1e-09,
13
- "batch_size": 4,
14
- "fp16_run": false,
15
- "lr_decay": 0.999875,
16
- "segment_size": 16384,
17
- "init_lr_ratio": 1,
18
- "warmup_epochs": 0,
19
- "c_mel": 45,
20
- "c_kl": 1.0,
21
- "skip_optimizer": true
22
- },
23
- "data": {
24
- "training_files": "filelists/train.list",
25
- "validation_files": "filelists/val.list",
26
- "max_wav_value": 32768.0,
27
- "sampling_rate": 44100,
28
- "filter_length": 2048,
29
- "hop_length": 512,
30
- "win_length": 2048,
31
- "n_mel_channels": 128,
32
- "mel_fmin": 0.0,
33
- "mel_fmax": null,
34
- "add_blank": true,
35
- "n_speakers": 256,
36
- "cleaned_text": true,
37
- "spk2id": {
38
- "特别周": 0,
39
- "无声铃鹿": 1,
40
- "丸善斯基": 2,
41
- "富士奇迹": 3,
42
- "东海帝皇": 4,
43
- "小栗帽": 5,
44
- "黄金船": 6,
45
- "伏特加": 7,
46
- "大和赤骥": 8,
47
- "菱亚马逊": 9,
48
- "草上飞": 10,
49
- "大树快车": 11,
50
- "目白麦昆": 12,
51
- "神鹰": 13,
52
- "鲁道夫象征": 14,
53
- "好歌剧": 15,
54
- "成田白仁": 16,
55
- "爱丽数码": 17,
56
- "美妙姿势": 18,
57
- "摩耶重炮": 19,
58
- "玉藻十字": 20,
59
- "琵琶晨光": 21,
60
- "目白赖恩": 22,
61
- "美浦波旁": 23,
62
- "雪中美人": 24,
63
- "米浴": 25,
64
- "爱丽速子": 26,
65
- "爱慕织姬": 27,
66
- "曼城茶座": 28,
67
- "气槽": 29,
68
- "星云天空": 30,
69
- "菱曙": 31,
70
- "艾尼斯风神": 32,
71
- "稻荷一": 33,
72
- "空中神宫": 34,
73
- "川上公主": 35,
74
- "黄金城": 36,
75
- "真机伶": 37,
76
- "荣进闪耀": 38,
77
- "采珠": 39,
78
- "新光风": 40,
79
- "超级小海湾": 41,
80
- "荒漠英雄": 42,
81
- "东瀛佐敦": 43,
82
- "中山庆典": 44,
83
- "成田大进": 45,
84
- "西野花": 46,
85
- "醒目飞鹰": 47,
86
- "春乌拉拉": 48,
87
- "青竹回忆": 49,
88
- "待兼福来": 50,
89
- "Mr CB": 51,
90
- "美丽周日": 52,
91
- "名将怒涛": 53,
92
- "帝王光辉": 54,
93
- "待兼诗歌剧": 55,
94
- "生野狄杜斯": 56,
95
- "优秀素质": 57,
96
- "双涡轮": 58,
97
- "目白多伯": 59,
98
- "目白善信": 60,
99
- "大拓太阳神": 61,
100
- "北部玄驹": 62,
101
- "目白阿尔丹": 63,
102
- "八重无敌": 64,
103
- "里见光钻": 65,
104
- "天狼星象征": 66,
105
- "樱花桂冠": 67,
106
- "成田路": 68,
107
- "也文摄辉": 69,
108
- "吉兆": 70,
109
- "鹤丸刚志": 71,
110
- "谷野美酒": 72,
111
- "第一红宝石": 73,
112
- "目白高峰": 74,
113
- "真弓快车": 75,
114
- "里见皇冠": 76,
115
- "高尚骏逸": 77,
116
- "凯斯奇迹": 78,
117
- "森林宝穴": 79,
118
- "小林力奇": 80,
119
- "奇瑞骏": 81,
120
- "葛城王牌": 82,
121
- "新宇宙": 83,
122
- "菱钻奇宝": 84,
123
- "望族": 85,
124
- "骏川手纲": 86,
125
- "秋川弥生": 87,
126
- "乙名史悦子": 88,
127
- "桐生院葵": 89,
128
- "安心泽刺刺美": 90,
129
- "达利阿拉伯": 91,
130
- "高多芬柏布": 92,
131
- "佐岳五月": 93,
132
- "胜利奖券": 94,
133
- "樱花进王": 95,
134
- "东商变革": 96,
135
- "微光飞驹": 97,
136
- "樱花千代王": 98,
137
- "跳舞城": 99,
138
- "樫本理子": 100,
139
- "明亮圣辉": 101,
140
- "拜耶土耳其": 102
141
- }
142
- },
143
- "model": {
144
- "use_spk_conditioned_encoder": true,
145
- "use_noise_scaled_mas": true,
146
- "use_mel_posterior_encoder": false,
147
- "use_duration_discriminator": true,
148
- "inter_channels": 192,
149
- "hidden_channels": 192,
150
- "filter_channels": 768,
151
- "n_heads": 2,
152
- "n_layers": 6,
153
- "kernel_size": 3,
154
- "p_dropout": 0.1,
155
- "resblock": "1",
156
- "resblock_kernel_sizes": [
157
- 3,
158
- 7,
159
- 11
160
- ],
161
- "resblock_dilation_sizes": [
162
- [
163
- 1,
164
- 3,
165
- 5
166
- ],
167
- [
168
- 1,
169
- 3,
170
- 5
171
- ],
172
- [
173
- 1,
174
- 3,
175
- 5
176
- ]
177
- ],
178
- "upsample_rates": [
179
- 8,
180
- 8,
181
- 2,
182
- 2,
183
- 2
184
- ],
185
- "upsample_initial_channel": 512,
186
- "upsample_kernel_sizes": [
187
- 16,
188
- 16,
189
- 8,
190
- 2,
191
- 2
192
- ],
193
- "n_layers_q": 3,
194
- "use_spectral_norm": false,
195
- "gin_channels": 256
196
- }
197
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
data_utils.py DELETED
@@ -1,406 +0,0 @@
1
- import os
2
- import random
3
- import torch
4
- import torch.utils.data
5
- from tqdm import tqdm
6
- from loguru import logger
7
- import commons
8
- from mel_processing import spectrogram_torch, mel_spectrogram_torch
9
- from utils import load_wav_to_torch, load_filepaths_and_text
10
- from text import cleaned_text_to_sequence, get_bert
11
-
12
- """Multi speaker version"""
13
-
14
-
15
- class TextAudioSpeakerLoader(torch.utils.data.Dataset):
16
- """
17
- 1) loads audio, speaker_id, text pairs
18
- 2) normalizes text and converts them to sequences of integers
19
- 3) computes spectrograms from audio files.
20
- """
21
-
22
- def __init__(self, audiopaths_sid_text, hparams):
23
- self.audiopaths_sid_text = load_filepaths_and_text(audiopaths_sid_text)
24
- self.max_wav_value = hparams.max_wav_value
25
- self.sampling_rate = hparams.sampling_rate
26
- self.filter_length = hparams.filter_length
27
- self.hop_length = hparams.hop_length
28
- self.win_length = hparams.win_length
29
- self.sampling_rate = hparams.sampling_rate
30
- self.spk_map = hparams.spk2id
31
- self.hparams = hparams
32
-
33
- self.use_mel_spec_posterior = getattr(
34
- hparams, "use_mel_posterior_encoder", False
35
- )
36
- if self.use_mel_spec_posterior:
37
- self.n_mel_channels = getattr(hparams, "n_mel_channels", 80)
38
-
39
- self.cleaned_text = getattr(hparams, "cleaned_text", False)
40
-
41
- self.add_blank = hparams.add_blank
42
- self.min_text_len = getattr(hparams, "min_text_len", 1)
43
- self.max_text_len = getattr(hparams, "max_text_len", 300)
44
-
45
- random.seed(1234)
46
- random.shuffle(self.audiopaths_sid_text)
47
- self._filter()
48
-
49
- def _filter(self):
50
- """
51
- Filter text & store spec lengths
52
- """
53
- # Store spectrogram lengths for Bucketing
54
- # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2)
55
- # spec_length = wav_length // hop_length
56
-
57
- audiopaths_sid_text_new = []
58
- lengths = []
59
- skipped = 0
60
- logger.info("Init dataset...")
61
- for _id, spk, language, text, phones, tone, word2ph in tqdm(
62
- self.audiopaths_sid_text
63
- ):
64
- audiopath = f"{_id}"
65
- if self.min_text_len <= len(phones) and len(phones) <= self.max_text_len:
66
- phones = phones.split(" ")
67
- tone = [int(i) for i in tone.split(" ")]
68
- word2ph = [int(i) for i in word2ph.split(" ")]
69
- audiopaths_sid_text_new.append(
70
- [audiopath, spk, language, text, phones, tone, word2ph]
71
- )
72
- lengths.append(os.path.getsize(audiopath) // (2 * self.hop_length))
73
- else:
74
- skipped += 1
75
- logger.info(
76
- "skipped: "
77
- + str(skipped)
78
- + ", total: "
79
- + str(len(self.audiopaths_sid_text))
80
- )
81
- self.audiopaths_sid_text = audiopaths_sid_text_new
82
- self.lengths = lengths
83
-
84
- def get_audio_text_speaker_pair(self, audiopath_sid_text):
85
- # separate filename, speaker_id and text
86
- audiopath, sid, language, text, phones, tone, word2ph = audiopath_sid_text
87
-
88
- bert, ja_bert, phones, tone, language = self.get_text(
89
- text, word2ph, phones, tone, language, audiopath
90
- )
91
-
92
- spec, wav = self.get_audio(audiopath)
93
- sid = torch.LongTensor([int(self.spk_map[sid])])
94
- return (phones, spec, wav, sid, tone, language, bert, ja_bert)
95
-
96
- def get_audio(self, filename):
97
- audio, sampling_rate = load_wav_to_torch(filename)
98
- if sampling_rate != self.sampling_rate:
99
- raise ValueError(
100
- "{} {} SR doesn't match target {} SR".format(
101
- filename, sampling_rate, self.sampling_rate
102
- )
103
- )
104
- audio_norm = audio / self.max_wav_value
105
- audio_norm = audio_norm.unsqueeze(0)
106
- spec_filename = filename.replace(".wav", ".spec.pt")
107
- if self.use_mel_spec_posterior:
108
- spec_filename = spec_filename.replace(".spec.pt", ".mel.pt")
109
- try:
110
- spec = torch.load(spec_filename)
111
- except:
112
- if self.use_mel_spec_posterior:
113
- spec = mel_spectrogram_torch(
114
- audio_norm,
115
- self.filter_length,
116
- self.n_mel_channels,
117
- self.sampling_rate,
118
- self.hop_length,
119
- self.win_length,
120
- self.hparams.mel_fmin,
121
- self.hparams.mel_fmax,
122
- center=False,
123
- )
124
- else:
125
- spec = spectrogram_torch(
126
- audio_norm,
127
- self.filter_length,
128
- self.sampling_rate,
129
- self.hop_length,
130
- self.win_length,
131
- center=False,
132
- )
133
- spec = torch.squeeze(spec, 0)
134
- torch.save(spec, spec_filename)
135
- return spec, audio_norm
136
-
137
- def get_text(self, text, word2ph, phone, tone, language_str, wav_path):
138
- phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
139
- if self.add_blank:
140
- phone = commons.intersperse(phone, 0)
141
- tone = commons.intersperse(tone, 0)
142
- language = commons.intersperse(language, 0)
143
- for i in range(len(word2ph)):
144
- word2ph[i] = word2ph[i] * 2
145
- word2ph[0] += 1
146
- bert_path = wav_path.replace(".wav", ".bert.pt")
147
- try:
148
- bert = torch.load(bert_path)
149
- assert bert.shape[-1] == len(phone)
150
- except:
151
- bert = get_bert(text, word2ph, language_str)
152
- torch.save(bert, bert_path)
153
- assert bert.shape[-1] == len(phone), phone
154
-
155
- if language_str == "ZH":
156
- bert = bert
157
- ja_bert = torch.zeros(768, len(phone))
158
- elif language_str == "JP":
159
- ja_bert = bert
160
- bert = torch.zeros(1024, len(phone))
161
- else:
162
- bert = torch.zeros(1024, len(phone))
163
- ja_bert = torch.zeros(768, len(phone))
164
- assert bert.shape[-1] == len(phone), (
165
- bert.shape,
166
- len(phone),
167
- sum(word2ph),
168
- p1,
169
- p2,
170
- t1,
171
- t2,
172
- pold,
173
- pold2,
174
- word2ph,
175
- text,
176
- w2pho,
177
- )
178
- phone = torch.LongTensor(phone)
179
- tone = torch.LongTensor(tone)
180
- language = torch.LongTensor(language)
181
- return bert, ja_bert, phone, tone, language
182
-
183
- def get_sid(self, sid):
184
- sid = torch.LongTensor([int(sid)])
185
- return sid
186
-
187
- def __getitem__(self, index):
188
- return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index])
189
-
190
- def __len__(self):
191
- return len(self.audiopaths_sid_text)
192
-
193
-
194
- class TextAudioSpeakerCollate:
195
- """Zero-pads model inputs and targets"""
196
-
197
- def __init__(self, return_ids=False):
198
- self.return_ids = return_ids
199
-
200
- def __call__(self, batch):
201
- """Collate's training batch from normalized text, audio and speaker identities
202
- PARAMS
203
- ------
204
- batch: [text_normalized, spec_normalized, wav_normalized, sid]
205
- """
206
- # Right zero-pad all one-hot text sequences to max input length
207
- _, ids_sorted_decreasing = torch.sort(
208
- torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True
209
- )
210
-
211
- max_text_len = max([len(x[0]) for x in batch])
212
- max_spec_len = max([x[1].size(1) for x in batch])
213
- max_wav_len = max([x[2].size(1) for x in batch])
214
-
215
- text_lengths = torch.LongTensor(len(batch))
216
- spec_lengths = torch.LongTensor(len(batch))
217
- wav_lengths = torch.LongTensor(len(batch))
218
- sid = torch.LongTensor(len(batch))
219
-
220
- text_padded = torch.LongTensor(len(batch), max_text_len)
221
- tone_padded = torch.LongTensor(len(batch), max_text_len)
222
- language_padded = torch.LongTensor(len(batch), max_text_len)
223
- bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
224
- ja_bert_padded = torch.FloatTensor(len(batch), 768, max_text_len)
225
-
226
- spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
227
- wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
228
- text_padded.zero_()
229
- tone_padded.zero_()
230
- language_padded.zero_()
231
- spec_padded.zero_()
232
- wav_padded.zero_()
233
- bert_padded.zero_()
234
- ja_bert_padded.zero_()
235
- for i in range(len(ids_sorted_decreasing)):
236
- row = batch[ids_sorted_decreasing[i]]
237
-
238
- text = row[0]
239
- text_padded[i, : text.size(0)] = text
240
- text_lengths[i] = text.size(0)
241
-
242
- spec = row[1]
243
- spec_padded[i, :, : spec.size(1)] = spec
244
- spec_lengths[i] = spec.size(1)
245
-
246
- wav = row[2]
247
- wav_padded[i, :, : wav.size(1)] = wav
248
- wav_lengths[i] = wav.size(1)
249
-
250
- sid[i] = row[3]
251
-
252
- tone = row[4]
253
- tone_padded[i, : tone.size(0)] = tone
254
-
255
- language = row[5]
256
- language_padded[i, : language.size(0)] = language
257
-
258
- bert = row[6]
259
- bert_padded[i, :, : bert.size(1)] = bert
260
-
261
- ja_bert = row[7]
262
- ja_bert_padded[i, :, : ja_bert.size(1)] = ja_bert
263
-
264
- return (
265
- text_padded,
266
- text_lengths,
267
- spec_padded,
268
- spec_lengths,
269
- wav_padded,
270
- wav_lengths,
271
- sid,
272
- tone_padded,
273
- language_padded,
274
- bert_padded,
275
- ja_bert_padded,
276
- )
277
-
278
-
279
- class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
280
- """
281
- Maintain similar input lengths in a batch.
282
- Length groups are specified by boundaries.
283
- Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}.
284
-
285
- It removes samples which are not included in the boundaries.
286
- Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
287
- """
288
-
289
- def __init__(
290
- self,
291
- dataset,
292
- batch_size,
293
- boundaries,
294
- num_replicas=None,
295
- rank=None,
296
- shuffle=True,
297
- ):
298
- super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
299
- self.lengths = dataset.lengths
300
- self.batch_size = batch_size
301
- self.boundaries = boundaries
302
-
303
- self.buckets, self.num_samples_per_bucket = self._create_buckets()
304
- self.total_size = sum(self.num_samples_per_bucket)
305
- self.num_samples = self.total_size // self.num_replicas
306
-
307
- def _create_buckets(self):
308
- buckets = [[] for _ in range(len(self.boundaries) - 1)]
309
- for i in range(len(self.lengths)):
310
- length = self.lengths[i]
311
- idx_bucket = self._bisect(length)
312
- if idx_bucket != -1:
313
- buckets[idx_bucket].append(i)
314
-
315
- try:
316
- for i in range(len(buckets) - 1, 0, -1):
317
- if len(buckets[i]) == 0:
318
- buckets.pop(i)
319
- self.boundaries.pop(i + 1)
320
- assert all(len(bucket) > 0 for bucket in buckets)
321
- # When one bucket is not traversed
322
- except Exception as e:
323
- print("Bucket warning ", e)
324
- for i in range(len(buckets) - 1, -1, -1):
325
- if len(buckets[i]) == 0:
326
- buckets.pop(i)
327
- self.boundaries.pop(i + 1)
328
-
329
- num_samples_per_bucket = []
330
- for i in range(len(buckets)):
331
- len_bucket = len(buckets[i])
332
- total_batch_size = self.num_replicas * self.batch_size
333
- rem = (
334
- total_batch_size - (len_bucket % total_batch_size)
335
- ) % total_batch_size
336
- num_samples_per_bucket.append(len_bucket + rem)
337
- return buckets, num_samples_per_bucket
338
-
339
- def __iter__(self):
340
- # deterministically shuffle based on epoch
341
- g = torch.Generator()
342
- g.manual_seed(self.epoch)
343
-
344
- indices = []
345
- if self.shuffle:
346
- for bucket in self.buckets:
347
- indices.append(torch.randperm(len(bucket), generator=g).tolist())
348
- else:
349
- for bucket in self.buckets:
350
- indices.append(list(range(len(bucket))))
351
-
352
- batches = []
353
- for i in range(len(self.buckets)):
354
- bucket = self.buckets[i]
355
- len_bucket = len(bucket)
356
- if len_bucket == 0:
357
- continue
358
- ids_bucket = indices[i]
359
- num_samples_bucket = self.num_samples_per_bucket[i]
360
-
361
- # add extra samples to make it evenly divisible
362
- rem = num_samples_bucket - len_bucket
363
- ids_bucket = (
364
- ids_bucket
365
- + ids_bucket * (rem // len_bucket)
366
- + ids_bucket[: (rem % len_bucket)]
367
- )
368
-
369
- # subsample
370
- ids_bucket = ids_bucket[self.rank :: self.num_replicas]
371
-
372
- # batching
373
- for j in range(len(ids_bucket) // self.batch_size):
374
- batch = [
375
- bucket[idx]
376
- for idx in ids_bucket[
377
- j * self.batch_size : (j + 1) * self.batch_size
378
- ]
379
- ]
380
- batches.append(batch)
381
-
382
- if self.shuffle:
383
- batch_ids = torch.randperm(len(batches), generator=g).tolist()
384
- batches = [batches[i] for i in batch_ids]
385
- self.batches = batches
386
-
387
- assert len(self.batches) * self.batch_size == self.num_samples
388
- return iter(self.batches)
389
-
390
- def _bisect(self, x, lo=0, hi=None):
391
- if hi is None:
392
- hi = len(self.boundaries) - 1
393
-
394
- if hi > lo:
395
- mid = (hi + lo) // 2
396
- if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
397
- return mid
398
- elif x <= self.boundaries[mid]:
399
- return self._bisect(x, lo, mid)
400
- else:
401
- return self._bisect(x, mid + 1, hi)
402
- else:
403
- return -1
404
-
405
- def __len__(self):
406
- return self.num_samples // self.batch_size
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
generation_logs.txt DELETED
The diff for this file is too large to render. See raw diff
 
losses.py DELETED
@@ -1,58 +0,0 @@
1
- import torch
2
-
3
-
4
- def feature_loss(fmap_r, fmap_g):
5
- loss = 0
6
- for dr, dg in zip(fmap_r, fmap_g):
7
- for rl, gl in zip(dr, dg):
8
- rl = rl.float().detach()
9
- gl = gl.float()
10
- loss += torch.mean(torch.abs(rl - gl))
11
-
12
- return loss * 2
13
-
14
-
15
- def discriminator_loss(disc_real_outputs, disc_generated_outputs):
16
- loss = 0
17
- r_losses = []
18
- g_losses = []
19
- for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
20
- dr = dr.float()
21
- dg = dg.float()
22
- r_loss = torch.mean((1 - dr) ** 2)
23
- g_loss = torch.mean(dg**2)
24
- loss += r_loss + g_loss
25
- r_losses.append(r_loss.item())
26
- g_losses.append(g_loss.item())
27
-
28
- return loss, r_losses, g_losses
29
-
30
-
31
- def generator_loss(disc_outputs):
32
- loss = 0
33
- gen_losses = []
34
- for dg in disc_outputs:
35
- dg = dg.float()
36
- l = torch.mean((1 - dg) ** 2)
37
- gen_losses.append(l)
38
- loss += l
39
-
40
- return loss, gen_losses
41
-
42
-
43
- def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
44
- """
45
- z_p, logs_q: [b, h, t_t]
46
- m_p, logs_p: [b, h, t_t]
47
- """
48
- z_p = z_p.float()
49
- logs_q = logs_q.float()
50
- m_p = m_p.float()
51
- logs_p = logs_p.float()
52
- z_mask = z_mask.float()
53
-
54
- kl = logs_p - logs_q - 0.5
55
- kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
56
- kl = torch.sum(kl * z_mask)
57
- l = kl / torch.sum(z_mask)
58
- return l
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
preprocess_text.py DELETED
@@ -1,107 +0,0 @@
1
- import json
2
- from collections import defaultdict
3
- from random import shuffle
4
- from typing import Optional
5
-
6
- from tqdm import tqdm
7
- import click
8
- from text.cleaner import clean_text
9
-
10
-
11
- @click.command()
12
- @click.option(
13
- "--transcription-path",
14
- default="filelists/genshin.list",
15
- type=click.Path(exists=True, file_okay=True, dir_okay=False),
16
- )
17
- @click.option("--cleaned-path", default=None)
18
- @click.option("--train-path", default="filelists/train.list")
19
- @click.option("--val-path", default="filelists/val.list")
20
- @click.option(
21
- "--config-path",
22
- default="configs/config.json",
23
- type=click.Path(exists=True, file_okay=True, dir_okay=False),
24
- )
25
- @click.option("--val-per-spk", default=4)
26
- @click.option("--max-val-total", default=8)
27
- @click.option("--clean/--no-clean", default=True)
28
- def main(
29
- transcription_path: str,
30
- cleaned_path: Optional[str],
31
- train_path: str,
32
- val_path: str,
33
- config_path: str,
34
- val_per_spk: int,
35
- max_val_total: int,
36
- clean: bool,
37
- ):
38
- if cleaned_path is None:
39
- cleaned_path = transcription_path + ".cleaned"
40
-
41
- if clean:
42
- errors = 0
43
- out_file = open(cleaned_path, "w", encoding="utf-8")
44
- for line in tqdm(open(transcription_path, encoding="utf-8").readlines()):
45
- try:
46
- utt, spk, language, text = line.strip().split("|")
47
- norm_text, phones, tones, word2ph = clean_text(text, language)
48
- out_file.write(
49
- "{}|{}|{}|{}|{}|{}|{}\n".format(
50
- utt,
51
- spk,
52
- language,
53
- norm_text,
54
- " ".join(phones),
55
- " ".join([str(i) for i in tones]),
56
- " ".join([str(i) for i in word2ph]),
57
- )
58
- )
59
- except Exception as error:
60
- errors += 1
61
- print("err!", line, error)
62
- print("errors:", errors)
63
- out_file.close()
64
-
65
- transcription_path = cleaned_path
66
-
67
- spk_utt_map = defaultdict(list)
68
- spk_id_map = {}
69
- current_sid = 0
70
-
71
- with open(transcription_path, encoding="utf-8") as f:
72
- for line in f.readlines():
73
- utt, spk, language, text, phones, tones, word2ph = line.strip().split("|")
74
- spk_utt_map[spk].append(line)
75
-
76
- if spk not in spk_id_map.keys():
77
- spk_id_map[spk] = current_sid
78
- current_sid += 1
79
-
80
- train_list = []
81
- val_list = []
82
-
83
- for spk, utts in spk_utt_map.items():
84
- shuffle(utts)
85
- val_list += utts[:val_per_spk]
86
- train_list += utts[val_per_spk:]
87
-
88
- if len(val_list) > max_val_total:
89
- train_list += val_list[max_val_total:]
90
- val_list = val_list[:max_val_total]
91
-
92
- with open(train_path, "w", encoding="utf-8") as f:
93
- for line in train_list:
94
- f.write(line)
95
-
96
- with open(val_path, "w", encoding="utf-8") as f:
97
- for line in val_list:
98
- f.write(line)
99
-
100
- config = json.load(open(config_path, encoding="utf-8"))
101
- config["data"]["spk2id"] = spk_id_map
102
- with open(config_path, "w", encoding="utf-8") as f:
103
- json.dump(config, f, indent=2, ensure_ascii=False)
104
-
105
-
106
- if __name__ == "__main__":
107
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
resample.py DELETED
@@ -1,48 +0,0 @@
1
- import os
2
- import argparse
3
- import librosa
4
- from multiprocessing import Pool, cpu_count
5
-
6
- import soundfile
7
- from tqdm import tqdm
8
-
9
-
10
- def process(item):
11
- spkdir, wav_name, args = item
12
- speaker = spkdir.replace("\\", "/").split("/")[-1]
13
- wav_path = os.path.join(args.in_dir, speaker, wav_name)
14
- if os.path.exists(wav_path) and ".wav" in wav_path:
15
- os.makedirs(os.path.join(args.out_dir, speaker), exist_ok=True)
16
- wav, sr = librosa.load(wav_path, sr=args.sr)
17
- soundfile.write(os.path.join(args.out_dir, speaker, wav_name), wav, sr)
18
-
19
-
20
- if __name__ == "__main__":
21
- parser = argparse.ArgumentParser()
22
- parser.add_argument("--sr", type=int, default=44100, help="sampling rate")
23
- parser.add_argument(
24
- "--in_dir", type=str, default="./raw", help="path to source dir"
25
- )
26
- parser.add_argument(
27
- "--out_dir", type=str, default="./dataset", help="path to target dir"
28
- )
29
- args = parser.parse_args()
30
- # processes = 8
31
- processes = cpu_count() - 2 if cpu_count() > 4 else 1
32
- pool = Pool(processes=processes)
33
-
34
- for speaker in os.listdir(args.in_dir):
35
- spk_dir = os.path.join(args.in_dir, speaker)
36
- if os.path.isdir(spk_dir):
37
- print(spk_dir)
38
- for _ in tqdm(
39
- pool.imap_unordered(
40
- process,
41
- [
42
- (spk_dir, i, args)
43
- for i in os.listdir(spk_dir)
44
- if i.endswith("wav")
45
- ],
46
- )
47
- ):
48
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
train_ms.py DELETED
@@ -1,596 +0,0 @@
1
- # flake8: noqa: E402
2
-
3
- import os
4
- import torch
5
- from torch.nn import functional as F
6
- from torch.utils.data import DataLoader
7
- from torch.utils.tensorboard import SummaryWriter
8
- import torch.distributed as dist
9
- from torch.nn.parallel import DistributedDataParallel as DDP
10
- from torch.cuda.amp import autocast, GradScaler
11
- from tqdm import tqdm
12
- import logging
13
-
14
- logging.getLogger("numba").setLevel(logging.WARNING)
15
- import commons
16
- import utils
17
- from data_utils import (
18
- TextAudioSpeakerLoader,
19
- TextAudioSpeakerCollate,
20
- DistributedBucketSampler,
21
- )
22
- from models import (
23
- SynthesizerTrn,
24
- MultiPeriodDiscriminator,
25
- DurationDiscriminator,
26
- )
27
- from losses import generator_loss, discriminator_loss, feature_loss, kl_loss
28
- from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
29
- from text.symbols import symbols
30
-
31
- torch.backends.cuda.matmul.allow_tf32 = True
32
- torch.backends.cudnn.allow_tf32 = (
33
- True # If encontered training problem,please try to disable TF32.
34
- )
35
- torch.set_float32_matmul_precision("medium")
36
- torch.backends.cudnn.benchmark = True
37
- torch.backends.cuda.sdp_kernel("flash")
38
- torch.backends.cuda.enable_flash_sdp(True)
39
- torch.backends.cuda.enable_mem_efficient_sdp(
40
- True
41
- ) # Not available if torch version is lower than 2.0
42
- torch.backends.cuda.enable_math_sdp(True)
43
- global_step = 0
44
-
45
-
46
- def run():
47
- dist.init_process_group(
48
- backend="gloo",
49
- init_method='tcp://127.0.0.1:11451', # Due to some training problem,we proposed to use gloo instead of nccl.
50
- rank=0,
51
- world_size=1,
52
- ) # Use torchrun instead of mp.spawn
53
- rank = dist.get_rank()
54
- n_gpus = dist.get_world_size()
55
- hps = utils.get_hparams()
56
- torch.manual_seed(hps.train.seed)
57
- torch.cuda.set_device(rank)
58
- global global_step
59
- if rank == 0:
60
- logger = utils.get_logger(hps.model_dir)
61
- logger.info(hps)
62
- utils.check_git_hash(hps.model_dir)
63
- writer = SummaryWriter(log_dir=hps.model_dir)
64
- writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))
65
- train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps.data)
66
- train_sampler = DistributedBucketSampler(
67
- train_dataset,
68
- hps.train.batch_size,
69
- [32, 300, 400, 500, 600, 700, 800, 900, 1000],
70
- num_replicas=n_gpus,
71
- rank=rank,
72
- shuffle=True,
73
- )
74
- collate_fn = TextAudioSpeakerCollate()
75
- train_loader = DataLoader(
76
- train_dataset,
77
- num_workers=16,
78
- shuffle=False,
79
- pin_memory=True,
80
- collate_fn=collate_fn,
81
- batch_sampler=train_sampler,
82
- persistent_workers=True,
83
- prefetch_factor=4,
84
- ) # DataLoader config could be adjusted.
85
- if rank == 0:
86
- eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data)
87
- eval_loader = DataLoader(
88
- eval_dataset,
89
- num_workers=0,
90
- shuffle=False,
91
- batch_size=1,
92
- pin_memory=True,
93
- drop_last=False,
94
- collate_fn=collate_fn,
95
- )
96
- if (
97
- "use_noise_scaled_mas" in hps.model.keys()
98
- and hps.model.use_noise_scaled_mas is True
99
- ):
100
- print("Using noise scaled MAS for VITS2")
101
- mas_noise_scale_initial = 0.01
102
- noise_scale_delta = 2e-6
103
- else:
104
- print("Using normal MAS for VITS1")
105
- mas_noise_scale_initial = 0.0
106
- noise_scale_delta = 0.0
107
- if (
108
- "use_duration_discriminator" in hps.model.keys()
109
- and hps.model.use_duration_discriminator is True
110
- ):
111
- print("Using duration discriminator for VITS2")
112
- net_dur_disc = DurationDiscriminator(
113
- hps.model.hidden_channels,
114
- hps.model.hidden_channels,
115
- 3,
116
- 0.1,
117
- gin_channels=hps.model.gin_channels if hps.data.n_speakers != 0 else 0,
118
- ).cuda(rank)
119
- if (
120
- "use_spk_conditioned_encoder" in hps.model.keys()
121
- and hps.model.use_spk_conditioned_encoder is True
122
- ):
123
- if hps.data.n_speakers == 0:
124
- raise ValueError(
125
- "n_speakers must be > 0 when using spk conditioned encoder to train multi-speaker model"
126
- )
127
- else:
128
- print("Using normal encoder for VITS1")
129
-
130
- net_g = SynthesizerTrn(
131
- len(symbols),
132
- hps.data.filter_length // 2 + 1,
133
- hps.train.segment_size // hps.data.hop_length,
134
- n_speakers=hps.data.n_speakers,
135
- mas_noise_scale_initial=mas_noise_scale_initial,
136
- noise_scale_delta=noise_scale_delta,
137
- **hps.model,
138
- ).cuda(rank)
139
-
140
- net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
141
- optim_g = torch.optim.AdamW(
142
- filter(lambda p: p.requires_grad, net_g.parameters()),
143
- hps.train.learning_rate,
144
- betas=hps.train.betas,
145
- eps=hps.train.eps,
146
- )
147
- optim_d = torch.optim.AdamW(
148
- net_d.parameters(),
149
- hps.train.learning_rate,
150
- betas=hps.train.betas,
151
- eps=hps.train.eps,
152
- )
153
- if net_dur_disc is not None:
154
- optim_dur_disc = torch.optim.AdamW(
155
- net_dur_disc.parameters(),
156
- hps.train.learning_rate,
157
- betas=hps.train.betas,
158
- eps=hps.train.eps,
159
- )
160
- else:
161
- optim_dur_disc = None
162
- net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
163
- net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
164
- if net_dur_disc is not None:
165
- net_dur_disc = DDP(net_dur_disc, device_ids=[rank], find_unused_parameters=True)
166
- try:
167
- if net_dur_disc is not None:
168
- _, _, dur_resume_lr, epoch_str = utils.load_checkpoint(
169
- utils.latest_checkpoint_path(hps.model_dir, "DUR_*.pth"),
170
- net_dur_disc,
171
- optim_dur_disc,
172
- skip_optimizer=hps.train.skip_optimizer
173
- if "skip_optimizer" in hps.train
174
- else True,
175
- )
176
- _, optim_g, g_resume_lr, epoch_str = utils.load_checkpoint(
177
- utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"),
178
- net_g,
179
- optim_g,
180
- skip_optimizer=hps.train.skip_optimizer
181
- if "skip_optimizer" in hps.train
182
- else True,
183
- )
184
- _, optim_d, d_resume_lr, epoch_str = utils.load_checkpoint(
185
- utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"),
186
- net_d,
187
- optim_d,
188
- skip_optimizer=hps.train.skip_optimizer
189
- if "skip_optimizer" in hps.train
190
- else True,
191
- )
192
- if not optim_g.param_groups[0].get("initial_lr"):
193
- optim_g.param_groups[0]["initial_lr"] = g_resume_lr
194
- if not optim_d.param_groups[0].get("initial_lr"):
195
- optim_d.param_groups[0]["initial_lr"] = d_resume_lr
196
- if not optim_dur_disc.param_groups[0].get("initial_lr"):
197
- optim_dur_disc.param_groups[0]["initial_lr"] = dur_resume_lr
198
-
199
- epoch_str = max(epoch_str, 1)
200
- global_step = (epoch_str - 1) * len(train_loader)
201
- except Exception as e:
202
- print(e)
203
- epoch_str = 1
204
- global_step = 0
205
-
206
- scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
207
- optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
208
- )
209
- scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
210
- optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
211
- )
212
- if net_dur_disc is not None:
213
- if not optim_dur_disc.param_groups[0].get("initial_lr"):
214
- optim_dur_disc.param_groups[0]["initial_lr"] = dur_resume_lr
215
- scheduler_dur_disc = torch.optim.lr_scheduler.ExponentialLR(
216
- optim_dur_disc, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
217
- )
218
- else:
219
- scheduler_dur_disc = None
220
- scaler = GradScaler(enabled=hps.train.fp16_run)
221
-
222
- for epoch in range(epoch_str, hps.train.epochs + 1):
223
- if rank == 0:
224
- train_and_evaluate(
225
- rank,
226
- epoch,
227
- hps,
228
- [net_g, net_d, net_dur_disc],
229
- [optim_g, optim_d, optim_dur_disc],
230
- [scheduler_g, scheduler_d, scheduler_dur_disc],
231
- scaler,
232
- [train_loader, eval_loader],
233
- logger,
234
- [writer, writer_eval],
235
- )
236
- else:
237
- train_and_evaluate(
238
- rank,
239
- epoch,
240
- hps,
241
- [net_g, net_d, net_dur_disc],
242
- [optim_g, optim_d, optim_dur_disc],
243
- [scheduler_g, scheduler_d, scheduler_dur_disc],
244
- scaler,
245
- [train_loader, None],
246
- None,
247
- None,
248
- )
249
- scheduler_g.step()
250
- scheduler_d.step()
251
- if net_dur_disc is not None:
252
- scheduler_dur_disc.step()
253
-
254
-
255
- def train_and_evaluate(
256
- rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
257
- ):
258
- net_g, net_d, net_dur_disc = nets
259
- optim_g, optim_d, optim_dur_disc = optims
260
- scheduler_g, scheduler_d, scheduler_dur_disc = schedulers
261
- train_loader, eval_loader = loaders
262
- if writers is not None:
263
- writer, writer_eval = writers
264
-
265
- train_loader.batch_sampler.set_epoch(epoch)
266
- global global_step
267
-
268
- net_g.train()
269
- net_d.train()
270
- if net_dur_disc is not None:
271
- net_dur_disc.train()
272
- for batch_idx, (
273
- x,
274
- x_lengths,
275
- spec,
276
- spec_lengths,
277
- y,
278
- y_lengths,
279
- speakers,
280
- tone,
281
- language,
282
- bert,
283
- ja_bert,
284
- ) in tqdm(enumerate(train_loader)):
285
- if net_g.module.use_noise_scaled_mas:
286
- current_mas_noise_scale = (
287
- net_g.module.mas_noise_scale_initial
288
- - net_g.module.noise_scale_delta * global_step
289
- )
290
- net_g.module.current_mas_noise_scale = max(current_mas_noise_scale, 0.0)
291
- x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda(
292
- rank, non_blocking=True
293
- )
294
- spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
295
- rank, non_blocking=True
296
- )
297
- y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
298
- rank, non_blocking=True
299
- )
300
- speakers = speakers.cuda(rank, non_blocking=True)
301
- tone = tone.cuda(rank, non_blocking=True)
302
- language = language.cuda(rank, non_blocking=True)
303
- bert = bert.cuda(rank, non_blocking=True)
304
- ja_bert = ja_bert.cuda(rank, non_blocking=True)
305
-
306
- with autocast(enabled=hps.train.fp16_run):
307
- (
308
- y_hat,
309
- l_length,
310
- attn,
311
- ids_slice,
312
- x_mask,
313
- z_mask,
314
- (z, z_p, m_p, logs_p, m_q, logs_q),
315
- (hidden_x, logw, logw_),
316
- ) = net_g(
317
- x,
318
- x_lengths,
319
- spec,
320
- spec_lengths,
321
- speakers,
322
- tone,
323
- language,
324
- bert,
325
- ja_bert,
326
- )
327
- mel = spec_to_mel_torch(
328
- spec,
329
- hps.data.filter_length,
330
- hps.data.n_mel_channels,
331
- hps.data.sampling_rate,
332
- hps.data.mel_fmin,
333
- hps.data.mel_fmax,
334
- )
335
- y_mel = commons.slice_segments(
336
- mel, ids_slice, hps.train.segment_size // hps.data.hop_length
337
- )
338
- y_hat_mel = mel_spectrogram_torch(
339
- y_hat.squeeze(1),
340
- hps.data.filter_length,
341
- hps.data.n_mel_channels,
342
- hps.data.sampling_rate,
343
- hps.data.hop_length,
344
- hps.data.win_length,
345
- hps.data.mel_fmin,
346
- hps.data.mel_fmax,
347
- )
348
-
349
- y = commons.slice_segments(
350
- y, ids_slice * hps.data.hop_length, hps.train.segment_size
351
- ) # slice
352
-
353
- # Discriminator
354
- y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
355
- with autocast(enabled=False):
356
- loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
357
- y_d_hat_r, y_d_hat_g
358
- )
359
- loss_disc_all = loss_disc
360
- if net_dur_disc is not None:
361
- y_dur_hat_r, y_dur_hat_g = net_dur_disc(
362
- hidden_x.detach(), x_mask.detach(), logw.detach(), logw_.detach()
363
- )
364
- with autocast(enabled=False):
365
- # TODO: I think need to mean using the mask, but for now, just mean all
366
- (
367
- loss_dur_disc,
368
- losses_dur_disc_r,
369
- losses_dur_disc_g,
370
- ) = discriminator_loss(y_dur_hat_r, y_dur_hat_g)
371
- loss_dur_disc_all = loss_dur_disc
372
- optim_dur_disc.zero_grad()
373
- scaler.scale(loss_dur_disc_all).backward()
374
- scaler.unscale_(optim_dur_disc)
375
- commons.clip_grad_value_(net_dur_disc.parameters(), None)
376
- scaler.step(optim_dur_disc)
377
-
378
- optim_d.zero_grad()
379
- scaler.scale(loss_disc_all).backward()
380
- scaler.unscale_(optim_d)
381
- grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
382
- scaler.step(optim_d)
383
-
384
- with autocast(enabled=hps.train.fp16_run):
385
- # Generator
386
- y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
387
- if net_dur_disc is not None:
388
- y_dur_hat_r, y_dur_hat_g = net_dur_disc(hidden_x, x_mask, logw, logw_)
389
- with autocast(enabled=False):
390
- loss_dur = torch.sum(l_length.float())
391
- loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
392
- loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
393
-
394
- loss_fm = feature_loss(fmap_r, fmap_g)
395
- loss_gen, losses_gen = generator_loss(y_d_hat_g)
396
- loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl
397
- if net_dur_disc is not None:
398
- loss_dur_gen, losses_dur_gen = generator_loss(y_dur_hat_g)
399
- loss_gen_all += loss_dur_gen
400
- optim_g.zero_grad()
401
- scaler.scale(loss_gen_all).backward()
402
- scaler.unscale_(optim_g)
403
- grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
404
- scaler.step(optim_g)
405
- scaler.update()
406
-
407
- if rank == 0:
408
- if global_step % hps.train.log_interval == 0:
409
- lr = optim_g.param_groups[0]["lr"]
410
- losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_dur, loss_kl]
411
- logger.info(
412
- "Train Epoch: {} [{:.0f}%]".format(
413
- epoch, 100.0 * batch_idx / len(train_loader)
414
- )
415
- )
416
- logger.info([x.item() for x in losses] + [global_step, lr])
417
-
418
- scalar_dict = {
419
- "loss/g/total": loss_gen_all,
420
- "loss/d/total": loss_disc_all,
421
- "learning_rate": lr,
422
- "grad_norm_d": grad_norm_d,
423
- "grad_norm_g": grad_norm_g,
424
- }
425
- scalar_dict.update(
426
- {
427
- "loss/g/fm": loss_fm,
428
- "loss/g/mel": loss_mel,
429
- "loss/g/dur": loss_dur,
430
- "loss/g/kl": loss_kl,
431
- }
432
- )
433
- scalar_dict.update(
434
- {"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}
435
- )
436
- scalar_dict.update(
437
- {"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}
438
- )
439
- scalar_dict.update(
440
- {"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}
441
- )
442
-
443
- image_dict = {
444
- "slice/mel_org": utils.plot_spectrogram_to_numpy(
445
- y_mel[0].data.cpu().numpy()
446
- ),
447
- "slice/mel_gen": utils.plot_spectrogram_to_numpy(
448
- y_hat_mel[0].data.cpu().numpy()
449
- ),
450
- "all/mel": utils.plot_spectrogram_to_numpy(
451
- mel[0].data.cpu().numpy()
452
- ),
453
- "all/attn": utils.plot_alignment_to_numpy(
454
- attn[0, 0].data.cpu().numpy()
455
- ),
456
- }
457
- utils.summarize(
458
- writer=writer,
459
- global_step=global_step,
460
- images=image_dict,
461
- scalars=scalar_dict,
462
- )
463
-
464
- if global_step % hps.train.eval_interval == 0:
465
- evaluate(hps, net_g, eval_loader, writer_eval)
466
- utils.save_checkpoint(
467
- net_g,
468
- optim_g,
469
- hps.train.learning_rate,
470
- epoch,
471
- os.path.join(hps.model_dir, "G_{}.pth".format(global_step)),
472
- )
473
- utils.save_checkpoint(
474
- net_d,
475
- optim_d,
476
- hps.train.learning_rate,
477
- epoch,
478
- os.path.join(hps.model_dir, "D_{}.pth".format(global_step)),
479
- )
480
- if net_dur_disc is not None:
481
- utils.save_checkpoint(
482
- net_dur_disc,
483
- optim_dur_disc,
484
- hps.train.learning_rate,
485
- epoch,
486
- os.path.join(hps.model_dir, "DUR_{}.pth".format(global_step)),
487
- )
488
- keep_ckpts = getattr(hps.train, "keep_ckpts", 5)
489
- if keep_ckpts > 0:
490
- utils.clean_checkpoints(
491
- path_to_models=hps.model_dir,
492
- n_ckpts_to_keep=keep_ckpts,
493
- sort_by_time=True,
494
- )
495
-
496
- global_step += 1
497
-
498
- if rank == 0:
499
- logger.info("====> Epoch: {}".format(epoch))
500
-
501
-
502
- def evaluate(hps, generator, eval_loader, writer_eval):
503
- generator.eval()
504
- image_dict = {}
505
- audio_dict = {}
506
- print("Evaluating ...")
507
- with torch.no_grad():
508
- for batch_idx, (
509
- x,
510
- x_lengths,
511
- spec,
512
- spec_lengths,
513
- y,
514
- y_lengths,
515
- speakers,
516
- tone,
517
- language,
518
- bert,
519
- ja_bert,
520
- ) in enumerate(eval_loader):
521
- x, x_lengths = x.cuda(), x_lengths.cuda()
522
- spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
523
- y, y_lengths = y.cuda(), y_lengths.cuda()
524
- speakers = speakers.cuda()
525
- bert = bert.cuda()
526
- ja_bert = ja_bert.cuda()
527
- tone = tone.cuda()
528
- language = language.cuda()
529
- for use_sdp in [True, False]:
530
- y_hat, attn, mask, *_ = generator.module.infer(
531
- x,
532
- x_lengths,
533
- speakers,
534
- tone,
535
- language,
536
- bert,
537
- ja_bert,
538
- y=spec,
539
- max_len=1000,
540
- sdp_ratio=0.0 if not use_sdp else 1.0,
541
- )
542
- y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
543
-
544
- mel = spec_to_mel_torch(
545
- spec,
546
- hps.data.filter_length,
547
- hps.data.n_mel_channels,
548
- hps.data.sampling_rate,
549
- hps.data.mel_fmin,
550
- hps.data.mel_fmax,
551
- )
552
- y_hat_mel = mel_spectrogram_torch(
553
- y_hat.squeeze(1).float(),
554
- hps.data.filter_length,
555
- hps.data.n_mel_channels,
556
- hps.data.sampling_rate,
557
- hps.data.hop_length,
558
- hps.data.win_length,
559
- hps.data.mel_fmin,
560
- hps.data.mel_fmax,
561
- )
562
- image_dict.update(
563
- {
564
- f"gen/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(
565
- y_hat_mel[0].cpu().numpy()
566
- )
567
- }
568
- )
569
- audio_dict.update(
570
- {
571
- f"gen/audio_{batch_idx}_{use_sdp}": y_hat[
572
- 0, :, : y_hat_lengths[0]
573
- ]
574
- }
575
- )
576
- image_dict.update(
577
- {
578
- f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(
579
- mel[0].cpu().numpy()
580
- )
581
- }
582
- )
583
- audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]})
584
-
585
- utils.summarize(
586
- writer=writer_eval,
587
- global_step=global_step,
588
- images=image_dict,
589
- audios=audio_dict,
590
- audio_sampling_rate=hps.data.sampling_rate,
591
- )
592
- generator.train()
593
-
594
-
595
- if __name__ == "__main__":
596
- run()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
train_ms_acc.py DELETED
@@ -1,623 +0,0 @@
1
- # flake8: noqa: E402
2
-
3
- import os
4
- import torch
5
- from torch.nn import functional as F
6
- from torch.utils.data import DataLoader
7
- from torch.utils.tensorboard import SummaryWriter
8
- import torch.distributed as dist
9
- from torch.nn.parallel import DistributedDataParallel as DDP
10
- from torch.cuda.amp import autocast, GradScaler
11
- from tqdm import tqdm
12
- import logging
13
-
14
- logging.getLogger("numba").setLevel(logging.WARNING)
15
- import commons
16
- import utils
17
- from data_utils import (
18
- TextAudioSpeakerLoader,
19
- TextAudioSpeakerCollate,
20
- DistributedBucketSampler,
21
- )
22
- from models import (
23
- SynthesizerTrn,
24
- MultiPeriodDiscriminator,
25
- DurationDiscriminator,
26
- )
27
- from losses import generator_loss, discriminator_loss, feature_loss, kl_loss
28
- from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
29
- from text.symbols import symbols
30
-
31
- torch.backends.cuda.matmul.allow_tf32 = True
32
- torch.backends.cudnn.allow_tf32 = (
33
- True # If encontered training problem,please try to disable TF32.
34
- )
35
- torch.set_float32_matmul_precision("medium")
36
- torch.backends.cudnn.benchmark = True
37
- torch.backends.cuda.sdp_kernel("flash")
38
- torch.backends.cuda.enable_flash_sdp(True)
39
- torch.backends.cuda.enable_mem_efficient_sdp(
40
- True
41
- ) # Not available if torch version is lower than 2.0
42
- torch.backends.cuda.enable_math_sdp(True)
43
- global_step = 0
44
-
45
-
46
- def run():
47
- dist.init_process_group(
48
- backend="gloo",
49
- init_method='tcp://127.0.0.1:11451', # Due to some training problem,we proposed to use gloo instead of nccl.
50
- rank=0,
51
- world_size=1,
52
- ) # Use torchrun instead of mp.spawn
53
- rank = dist.get_rank()
54
- n_gpus = dist.get_world_size()
55
- hps = utils.get_hparams()
56
- torch.manual_seed(hps.train.seed)
57
- torch.cuda.set_device(rank)
58
- global global_step
59
- if rank == 0:
60
- logger = utils.get_logger(hps.model_dir)
61
- logger.info(hps)
62
- utils.check_git_hash(hps.model_dir)
63
- writer = SummaryWriter(log_dir=hps.model_dir)
64
- writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))
65
- train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps.data)
66
- train_sampler = DistributedBucketSampler(
67
- train_dataset,
68
- hps.train.batch_size,
69
- [32, 300, 400, 500, 600, 700, 800, 900, 1000],
70
- num_replicas=n_gpus,
71
- rank=rank,
72
- shuffle=True,
73
- )
74
- collate_fn = TextAudioSpeakerCollate()
75
- train_loader = DataLoader(
76
- train_dataset,
77
- num_workers=16,
78
- shuffle=False,
79
- pin_memory=True,
80
- collate_fn=collate_fn,
81
- batch_sampler=train_sampler,
82
- persistent_workers=True,
83
- prefetch_factor=4,
84
- ) # DataLoader config could be adjusted.
85
- if rank == 0:
86
- eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data)
87
- eval_loader = DataLoader(
88
- eval_dataset,
89
- num_workers=0,
90
- shuffle=False,
91
- batch_size=1,
92
- pin_memory=True,
93
- drop_last=False,
94
- collate_fn=collate_fn,
95
- )
96
- if (
97
- "use_noise_scaled_mas" in hps.model.keys()
98
- and hps.model.use_noise_scaled_mas is True
99
- ):
100
- print("Using noise scaled MAS for VITS2")
101
- mas_noise_scale_initial = 0.01
102
- noise_scale_delta = 2e-6
103
- else:
104
- print("Using normal MAS for VITS1")
105
- mas_noise_scale_initial = 0.0
106
- noise_scale_delta = 0.0
107
- if (
108
- "use_duration_discriminator" in hps.model.keys()
109
- and hps.model.use_duration_discriminator is True
110
- ):
111
- print("Using duration discriminator for VITS2")
112
- net_dur_disc = DurationDiscriminator(
113
- hps.model.hidden_channels,
114
- hps.model.hidden_channels,
115
- 3,
116
- 0.1,
117
- gin_channels=hps.model.gin_channels if hps.data.n_speakers != 0 else 0,
118
- ).cuda(rank)
119
- if (
120
- "use_spk_conditioned_encoder" in hps.model.keys()
121
- and hps.model.use_spk_conditioned_encoder is True
122
- ):
123
- if hps.data.n_speakers == 0:
124
- raise ValueError(
125
- "n_speakers must be > 0 when using spk conditioned encoder to train multi-speaker model"
126
- )
127
- else:
128
- print("Using normal encoder for VITS1")
129
-
130
- net_g = SynthesizerTrn(
131
- len(symbols),
132
- hps.data.filter_length // 2 + 1,
133
- hps.train.segment_size // hps.data.hop_length,
134
- n_speakers=hps.data.n_speakers,
135
- mas_noise_scale_initial=mas_noise_scale_initial,
136
- noise_scale_delta=noise_scale_delta,
137
- **hps.model,
138
- ).cuda(rank)
139
-
140
- net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
141
- optim_g = torch.optim.AdamW(
142
- filter(lambda p: p.requires_grad, net_g.parameters()),
143
- hps.train.learning_rate,
144
- betas=hps.train.betas,
145
- eps=hps.train.eps,
146
- )
147
- optim_d = torch.optim.AdamW(
148
- net_d.parameters(),
149
- hps.train.learning_rate,
150
- betas=hps.train.betas,
151
- eps=hps.train.eps,
152
- )
153
- if net_dur_disc is not None:
154
- optim_dur_disc = torch.optim.AdamW(
155
- net_dur_disc.parameters(),
156
- hps.train.learning_rate,
157
- betas=hps.train.betas,
158
- eps=hps.train.eps,
159
- )
160
- else:
161
- optim_dur_disc = None
162
- net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
163
- net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
164
- if net_dur_disc is not None:
165
- net_dur_disc = DDP(net_dur_disc, device_ids=[rank], find_unused_parameters=True)
166
- try:
167
- if net_dur_disc is not None:
168
- _, _, dur_resume_lr, epoch_str = utils.load_checkpoint(
169
- utils.latest_checkpoint_path(hps.model_dir, "DUR_*.pth"),
170
- net_dur_disc,
171
- optim_dur_disc,
172
- skip_optimizer=hps.train.skip_optimizer
173
- if "skip_optimizer" in hps.train
174
- else True,
175
- )
176
- _, optim_g, g_resume_lr, epoch_str = utils.load_checkpoint(
177
- utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"),
178
- net_g,
179
- optim_g,
180
- skip_optimizer=hps.train.skip_optimizer
181
- if "skip_optimizer" in hps.train
182
- else True,
183
- )
184
- _, optim_d, d_resume_lr, epoch_str = utils.load_checkpoint(
185
- utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"),
186
- net_d,
187
- optim_d,
188
- skip_optimizer=hps.train.skip_optimizer
189
- if "skip_optimizer" in hps.train
190
- else True,
191
- )
192
- if not optim_g.param_groups[0].get("initial_lr"):
193
- optim_g.param_groups[0]["initial_lr"] = g_resume_lr
194
- if not optim_d.param_groups[0].get("initial_lr"):
195
- optim_d.param_groups[0]["initial_lr"] = d_resume_lr
196
- if not optim_dur_disc.param_groups[0].get("initial_lr"):
197
- optim_dur_disc.param_groups[0]["initial_lr"] = dur_resume_lr
198
-
199
- epoch_str = max(epoch_str, 1)
200
- global_step = (epoch_str - 1) * len(train_loader)
201
- except Exception as e:
202
- print(e)
203
- epoch_str = 1
204
- global_step = 0
205
-
206
- scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
207
- optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
208
- )
209
- scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
210
- optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
211
- )
212
- if net_dur_disc is not None:
213
- if not optim_dur_disc.param_groups[0].get("initial_lr"):
214
- optim_dur_disc.param_groups[0]["initial_lr"] = dur_resume_lr
215
- scheduler_dur_disc = torch.optim.lr_scheduler.ExponentialLR(
216
- optim_dur_disc, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
217
- )
218
- else:
219
- scheduler_dur_disc = None
220
- scaler = GradScaler(enabled=hps.train.fp16_run)
221
-
222
-
223
-
224
-
225
- for epoch in range(epoch_str, hps.train.epochs + 1):
226
- if rank == 0:
227
- train_and_evaluate(
228
- rank,
229
- epoch,
230
- hps,
231
- [net_g, net_d, net_dur_disc],
232
- [optim_g, optim_d, optim_dur_disc],
233
- [scheduler_g, scheduler_d, scheduler_dur_disc],
234
- scaler,
235
- [train_loader, eval_loader],
236
- logger,
237
- [writer, writer_eval],
238
- )
239
- else:
240
- train_and_evaluate(
241
- rank,
242
- epoch,
243
- hps,
244
- [net_g, net_d, net_dur_disc],
245
- [optim_g, optim_d, optim_dur_disc],
246
- [scheduler_g, scheduler_d, scheduler_dur_disc],
247
- scaler,
248
- [train_loader, None],
249
- None,
250
- None,
251
- )
252
- scheduler_g.step()
253
- scheduler_d.step()
254
- if net_dur_disc is not None:
255
- scheduler_dur_disc.step()
256
-
257
-
258
- __ACCUMULATION_STEP__ = 6
259
- __CURRENT_ACCUMULATION_STEP__ = 0
260
-
261
- def train_and_evaluate(
262
- rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
263
- ):
264
- global __ACCUMULATION_STEP__
265
- global __CURRENT_ACCUMULATION_STEP__
266
- net_g, net_d, net_dur_disc = nets
267
- optim_g, optim_d, optim_dur_disc = optims
268
- scheduler_g, scheduler_d, scheduler_dur_disc = schedulers
269
- train_loader, eval_loader = loaders
270
- if writers is not None:
271
- writer, writer_eval = writers
272
-
273
- train_loader.batch_sampler.set_epoch(epoch)
274
- global global_step
275
-
276
- net_g.train()
277
- net_d.train()
278
- if net_dur_disc is not None:
279
- net_dur_disc.train()
280
- for batch_idx, (
281
- x,
282
- x_lengths,
283
- spec,
284
- spec_lengths,
285
- y,
286
- y_lengths,
287
- speakers,
288
- tone,
289
- language,
290
- bert,
291
- ja_bert,
292
- ) in tqdm(enumerate(train_loader)):
293
- if net_g.module.use_noise_scaled_mas:
294
- current_mas_noise_scale = (
295
- net_g.module.mas_noise_scale_initial
296
- - net_g.module.noise_scale_delta * global_step
297
- )
298
- net_g.module.current_mas_noise_scale = max(current_mas_noise_scale, 0.0)
299
- x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda(
300
- rank, non_blocking=True
301
- )
302
- spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
303
- rank, non_blocking=True
304
- )
305
- y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
306
- rank, non_blocking=True
307
- )
308
- speakers = speakers.cuda(rank, non_blocking=True)
309
- tone = tone.cuda(rank, non_blocking=True)
310
- language = language.cuda(rank, non_blocking=True)
311
- bert = bert.cuda(rank, non_blocking=True)
312
- ja_bert = ja_bert.cuda(rank, non_blocking=True)
313
-
314
- with autocast(enabled=hps.train.fp16_run):
315
- (
316
- y_hat,
317
- l_length,
318
- attn,
319
- ids_slice,
320
- x_mask,
321
- z_mask,
322
- (z, z_p, m_p, logs_p, m_q, logs_q),
323
- (hidden_x, logw, logw_),
324
- ) = net_g(
325
- x,
326
- x_lengths,
327
- spec,
328
- spec_lengths,
329
- speakers,
330
- tone,
331
- language,
332
- bert,
333
- ja_bert,
334
- )
335
- mel = spec_to_mel_torch(
336
- spec,
337
- hps.data.filter_length,
338
- hps.data.n_mel_channels,
339
- hps.data.sampling_rate,
340
- hps.data.mel_fmin,
341
- hps.data.mel_fmax,
342
- )
343
- y_mel = commons.slice_segments(
344
- mel, ids_slice, hps.train.segment_size // hps.data.hop_length
345
- )
346
- y_hat_mel = mel_spectrogram_torch(
347
- y_hat.squeeze(1),
348
- hps.data.filter_length,
349
- hps.data.n_mel_channels,
350
- hps.data.sampling_rate,
351
- hps.data.hop_length,
352
- hps.data.win_length,
353
- hps.data.mel_fmin,
354
- hps.data.mel_fmax,
355
- )
356
-
357
- y = commons.slice_segments(
358
- y, ids_slice * hps.data.hop_length, hps.train.segment_size
359
- ) # slice
360
-
361
- # Discriminator
362
- y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
363
- with autocast(enabled=False):
364
- loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
365
- y_d_hat_r, y_d_hat_g
366
- )
367
- loss_disc_all = loss_disc
368
- if net_dur_disc is not None:
369
- y_dur_hat_r, y_dur_hat_g = net_dur_disc(
370
- hidden_x.detach(), x_mask.detach(), logw.detach(), logw_.detach()
371
- )
372
- with autocast(enabled=False):
373
- # TODO: I think need to mean using the mask, but for now, just mean all
374
- (
375
- loss_dur_disc,
376
- losses_dur_disc_r,
377
- losses_dur_disc_g,
378
- ) = discriminator_loss(y_dur_hat_r, y_dur_hat_g)
379
- loss_dur_disc_all = loss_dur_disc
380
- optim_dur_disc.zero_grad()
381
- scaler.scale(loss_dur_disc_all).backward()
382
- scaler.unscale_(optim_dur_disc)
383
- commons.clip_grad_value_(net_dur_disc.parameters(), None)
384
- scaler.step(optim_dur_disc)
385
-
386
-
387
-
388
- scaler.scale(loss_disc_all/__ACCUMULATION_STEP__).backward()
389
- __CURRENT_ACCUMULATION_STEP__ += 1
390
-
391
- if __CURRENT_ACCUMULATION_STEP__ == __ACCUMULATION_STEP__:
392
- __CURRENT_ACCUMULATION_STEP__ = 0
393
-
394
- scaler.unscale_(optim_d)
395
- grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
396
- scaler.step(optim_d)
397
- optim_d.zero_grad()
398
-
399
-
400
-
401
-
402
- with autocast(enabled=hps.train.fp16_run):
403
- # Generator
404
- y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
405
- if net_dur_disc is not None:
406
- y_dur_hat_r, y_dur_hat_g = net_dur_disc(hidden_x, x_mask, logw, logw_)
407
- with autocast(enabled=False):
408
- loss_dur = torch.sum(l_length.float())
409
- loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
410
- loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
411
-
412
- loss_fm = feature_loss(fmap_r, fmap_g)
413
- loss_gen, losses_gen = generator_loss(y_d_hat_g)
414
- loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl
415
- if net_dur_disc is not None:
416
- loss_dur_gen, losses_dur_gen = generator_loss(y_dur_hat_g)
417
- loss_gen_all += loss_dur_gen
418
-
419
-
420
- scaler.scale(loss_gen_all/__ACCUMULATION_STEP__).backward()
421
- if __CURRENT_ACCUMULATION_STEP__ == __ACCUMULATION_STEP__:
422
- __CURRENT_ACCUMULATION_STEP__ = 0
423
-
424
- scaler.unscale_(optim_g)
425
- grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
426
- scaler.step(optim_g)
427
- scaler.update()
428
- optim_g.zero_grad()
429
-
430
-
431
-
432
-
433
- if rank == 0:
434
- if (global_step-1) % hps.train.log_interval == 0:
435
- lr = optim_g.param_groups[0]["lr"]
436
- losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_dur, loss_kl]
437
- logger.info(
438
- "Train Epoch: {} [{:.0f}%]".format(
439
- epoch, 100.0 * batch_idx / len(train_loader)
440
- )
441
- )
442
- logger.info([x.item() for x in losses] + [global_step, lr])
443
-
444
- scalar_dict = {
445
- "loss/g/total": loss_gen_all,
446
- "loss/d/total": loss_disc_all,
447
- "learning_rate": lr,
448
- "grad_norm_d": grad_norm_d,
449
- "grad_norm_g": grad_norm_g,
450
- }
451
- scalar_dict.update(
452
- {
453
- "loss/g/fm": loss_fm,
454
- "loss/g/mel": loss_mel,
455
- "loss/g/dur": loss_dur,
456
- "loss/g/kl": loss_kl,
457
- }
458
- )
459
- scalar_dict.update(
460
- {"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}
461
- )
462
- scalar_dict.update(
463
- {"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}
464
- )
465
- scalar_dict.update(
466
- {"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}
467
- )
468
-
469
- image_dict = {
470
- "slice/mel_org": utils.plot_spectrogram_to_numpy(
471
- y_mel[0].data.cpu().numpy()
472
- ),
473
- "slice/mel_gen": utils.plot_spectrogram_to_numpy(
474
- y_hat_mel[0].data.cpu().numpy()
475
- ),
476
- "all/mel": utils.plot_spectrogram_to_numpy(
477
- mel[0].data.cpu().numpy()
478
- ),
479
- "all/attn": utils.plot_alignment_to_numpy(
480
- attn[0, 0].data.cpu().numpy()
481
- ),
482
- }
483
- utils.summarize(
484
- writer=writer,
485
- global_step=global_step,
486
- images=image_dict,
487
- scalars=scalar_dict,
488
- )
489
-
490
- if (global_step-1) % hps.train.eval_interval == 0:
491
- evaluate(hps, net_g, eval_loader, writer_eval)
492
- utils.save_checkpoint(
493
- net_g,
494
- optim_g,
495
- hps.train.learning_rate,
496
- epoch,
497
- os.path.join(hps.model_dir, "G_{}.pth".format(global_step)),
498
- )
499
- utils.save_checkpoint(
500
- net_d,
501
- optim_d,
502
- hps.train.learning_rate,
503
- epoch,
504
- os.path.join(hps.model_dir, "D_{}.pth".format(global_step)),
505
- )
506
- if net_dur_disc is not None:
507
- utils.save_checkpoint(
508
- net_dur_disc,
509
- optim_dur_disc,
510
- hps.train.learning_rate,
511
- epoch,
512
- os.path.join(hps.model_dir, "DUR_{}.pth".format(global_step)),
513
- )
514
- keep_ckpts = getattr(hps.train, "keep_ckpts", 5)
515
- if keep_ckpts > 0:
516
- utils.clean_checkpoints(
517
- path_to_models=hps.model_dir,
518
- n_ckpts_to_keep=keep_ckpts,
519
- sort_by_time=True,
520
- )
521
-
522
- global_step += 1
523
-
524
- if rank == 0:
525
- logger.info("====> Epoch: {} ===>{}".format(epoch, __CURRENT_ACCUMULATION_STEP__))
526
-
527
-
528
-
529
- def evaluate(hps, generator, eval_loader, writer_eval):
530
- generator.eval()
531
- image_dict = {}
532
- audio_dict = {}
533
- print("Evaluating ...")
534
- with torch.no_grad():
535
- for batch_idx, (
536
- x,
537
- x_lengths,
538
- spec,
539
- spec_lengths,
540
- y,
541
- y_lengths,
542
- speakers,
543
- tone,
544
- language,
545
- bert,
546
- ja_bert,
547
- ) in enumerate(eval_loader):
548
- x, x_lengths = x.cuda(), x_lengths.cuda()
549
- spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
550
- y, y_lengths = y.cuda(), y_lengths.cuda()
551
- speakers = speakers.cuda()
552
- bert = bert.cuda()
553
- ja_bert = ja_bert.cuda()
554
- tone = tone.cuda()
555
- language = language.cuda()
556
- for use_sdp in [True, False]:
557
- y_hat, attn, mask, *_ = generator.module.infer(
558
- x,
559
- x_lengths,
560
- speakers,
561
- tone,
562
- language,
563
- bert,
564
- ja_bert,
565
- y=spec,
566
- max_len=1000,
567
- sdp_ratio=0.0 if not use_sdp else 1.0,
568
- )
569
- y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
570
-
571
- mel = spec_to_mel_torch(
572
- spec,
573
- hps.data.filter_length,
574
- hps.data.n_mel_channels,
575
- hps.data.sampling_rate,
576
- hps.data.mel_fmin,
577
- hps.data.mel_fmax,
578
- )
579
- y_hat_mel = mel_spectrogram_torch(
580
- y_hat.squeeze(1).float(),
581
- hps.data.filter_length,
582
- hps.data.n_mel_channels,
583
- hps.data.sampling_rate,
584
- hps.data.hop_length,
585
- hps.data.win_length,
586
- hps.data.mel_fmin,
587
- hps.data.mel_fmax,
588
- )
589
- image_dict.update(
590
- {
591
- f"gen/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(
592
- y_hat_mel[0].cpu().numpy()
593
- )
594
- }
595
- )
596
- audio_dict.update(
597
- {
598
- f"gen/audio_{batch_idx}_{use_sdp}": y_hat[
599
- 0, :, : y_hat_lengths[0]
600
- ]
601
- }
602
- )
603
- image_dict.update(
604
- {
605
- f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(
606
- mel[0].cpu().numpy()
607
- )
608
- }
609
- )
610
- audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]})
611
-
612
- utils.summarize(
613
- writer=writer_eval,
614
- global_step=global_step,
615
- images=image_dict,
616
- audios=audio_dict,
617
- audio_sampling_rate=hps.data.sampling_rate,
618
- )
619
- generator.train()
620
-
621
-
622
- if __name__ == "__main__":
623
- run()