wasmdashai committed on
Commit
d148bcd
1 Parent(s): 686c4a5

model push

VitsModelSplit/.gitignore ADDED
@@ -0,0 +1,163 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+ models--facebook--mms-tts-ara/
9
+ output/
10
+ dataset/
11
+ wandb/
12
+ # Distribution / packaging
13
+ .Python
14
+ build/
15
+ develop-eggs/
16
+ dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ share/python-wheels/
27
+ *.egg-info/
28
+ .installed.cfg
29
+ *.egg
30
+ MANIFEST
31
+
32
+ # PyInstaller
33
+ # Usually these files are written by a python script from a template
34
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
35
+ *.manifest
36
+ *.spec
37
+
38
+ # Installer logs
39
+ pip-log.txt
40
+ pip-delete-this-directory.txt
41
+
42
+ # Unit test / coverage reports
43
+ htmlcov/
44
+ .tox/
45
+ .nox/
46
+ .coverage
47
+ .coverage.*
48
+ .cache
49
+ nosetests.xml
50
+ coverage.xml
51
+ *.cover
52
+ *.py,cover
53
+ .hypothesis/
54
+ .pytest_cache/
55
+ cover/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ .pybuilder/
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ # For a library or package, you might want to ignore these files since the code is
90
+ # intended to run in multiple environments; otherwise, check them in:
91
+ # .python-version
92
+
93
+ # pipenv
94
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
96
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
97
+ # install all needed dependencies.
98
+ #Pipfile.lock
99
+
100
+ # poetry
101
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
103
+ # commonly ignored for libraries.
104
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105
+ #poetry.lock
106
+
107
+ # pdm
108
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109
+ #pdm.lock
110
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111
+ # in version control.
112
+ # https://pdm.fming.dev/#use-with-ide
113
+ .pdm.toml
114
+
115
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
116
+ __pypackages__/
117
+
118
+ # Celery stuff
119
+ celerybeat-schedule
120
+ celerybeat.pid
121
+
122
+ # SageMath parsed files
123
+ *.sage.py
124
+
125
+ # Environments
126
+ .env
127
+ .venv
128
+ env/
129
+ venv/
130
+ ENV/
131
+ env.bak/
132
+ venv.bak/
133
+
134
+ # Spyder project settings
135
+ .spyderproject
136
+ .spyproject
137
+
138
+ # Rope project settings
139
+ .ropeproject
140
+
141
+ # mkdocs documentation
142
+ /site
143
+
144
+ # mypy
145
+ .mypy_cache/
146
+ .dmypy.json
147
+ dmypy.json
148
+
149
+ # Pyre type checker
150
+ .pyre/
151
+
152
+ # pytype static type analyzer
153
+ .pytype/
154
+
155
+ # Cython debug symbols
156
+ cython_debug/
157
+
158
+ # PyCharm
159
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
162
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163
+ #.idea/
VitsModelSplit/Arguments.py ADDED
@@ -0,0 +1,265 @@
1
+
2
+ from transformers import TrainingArguments
3
+ from typing import Any, Optional
4
+ from dataclasses import dataclass, field
5
+
6
+
7
+ #.............................................
8
+
9
+ #### ARGUMENTS
10
+
11
+
12
+ @dataclass
13
+ class ModelArguments:
14
+ """
15
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
16
+ """
17
+
18
+ model_name_or_path: str = field(
19
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
20
+ )
21
+ config_name: Optional[str] = field(
22
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
23
+ )
24
+ tokenizer_name: Optional[str] = field(
25
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
26
+ )
27
+ feature_extractor_name: Optional[str] = field(
28
+ default=None, metadata={"help": "feature extractor name or path if not the same as model_name"}
29
+ )
30
+ cache_dir: Optional[str] = field(
31
+ default=None,
32
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
33
+ )
34
+ use_fast_tokenizer: bool = field(
35
+ default=True,
36
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
37
+ )
38
+ model_revision: str = field(
39
+ default="main",
40
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
41
+ )
42
+ token: str = field(
43
+ default=None,
44
+ metadata={
45
+ "help": (
46
+ "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
47
+ "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
48
+ )
49
+ },
50
+ )
51
+ use_auth_token: bool = field(
52
+ default=None,
53
+ metadata={
54
+ "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`."
55
+ },
56
+ )
57
+ trust_remote_code: bool = field(
58
+ default=False,
59
+ metadata={
60
+ "help": (
61
+ "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option"
62
+ "should only be set to `True` for repositories you trust and in which you have read the code, as it will"
63
+ "execute code present on the Hub on your local machine."
64
+ )
65
+ },
66
+ )
67
+ override_speaker_embeddings: bool = field(
68
+ default=False,
69
+ metadata={
70
+ "help": (
71
+ "If `True` and if `speaker_id_column_name` is specified, it will replace current speaker embeddings with a new set of speaker embeddings."
72
+ "If the model from the checkpoint didn't have speaker embeddings, it will initialize speaker embeddings."
73
+ )
74
+ },
75
+ )
76
+
77
+ override_vocabulary_embeddings: bool = field(
78
+ default=False,
79
+ metadata={
80
+ "help": (
81
+ "If `True`, it will resize the token embeddings based on the vocabulary size of the tokenizer. In other words, use this when you use a different tokenizer than the one that was used during pretraining."
82
+ )
83
+ },
84
+ )
85
+
86
+ #.............................................................................................
87
+
88
+
89
+ @dataclass
90
+ class VITSTrainingArguments(TrainingArguments):
91
+ do_step_schedule_per_epoch: bool = field(
92
+ default=True,
93
+ metadata={
94
+ "help": (
95
+ "Whether or not to perform scheduler steps per epoch or per steps. If `True`, the scheduler will be `ExponentialLR` parametrized with `lr_decay`."
96
+ )
97
+ },
98
+ )
99
+
100
+ lr_decay: float = field(
101
+ default=0.999875,
102
+ metadata={"help": "Learning rate decay, used with `ExponentialLR` when `do_step_schedule_per_epoch`."},
103
+ )
104
+
105
+ weight_duration: float = field(default=1.0, metadata={"help": "Duration loss weight."})
106
+
107
+ weight_kl: float = field(default=1.5, metadata={"help": "KL loss weight."})
108
+
109
+ weight_mel: float = field(default=35.0, metadata={"help": "Mel-spectrogram loss weight"})
110
+
111
+ weight_disc: float = field(default=3.0, metadata={"help": "Discriminator loss weight"})
112
+
113
+ weight_gen: float = field(default=1.0, metadata={"help": "Generator loss weight"})
114
+
115
+ weight_fmaps: float = field(default=1.0, metadata={"help": "Feature map loss weight"})
116
+ d_learning_rate: float = field(default=2e-4, metadata={"help": "Discriminator learning rate"})
117
+
118
+ d_adam_beta1: float = field(default=0.8, metadata={"help": "Beta1 for the discriminator AdamW optimizer"})
119
+ d_adam_beta2: float = field(default=0.99, metadata={"help": "Beta2 for the discriminator AdamW optimizer"})
120
+
121
+
122
+ #.............................................................................................
123
+
124
+ @dataclass
125
+ class DataTrainingArguments:
126
+ """
127
+ Arguments pertaining to what data we are going to input our model for training and eval.
128
+ """
129
+
130
+ project_name: str = field(
131
+ default="vits_finetuning",
132
+ metadata={"help": "The project name associated to this run. Useful to track your experiment."},
133
+ )
134
+ dataset_name: str = field(
135
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
136
+ )
137
+ dataset_config_name: Optional[str] = field(
138
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
139
+ )
140
+ overwrite_cache: bool = field(
141
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
142
+ )
143
+ preprocessing_num_workers: Optional[int] = field(
144
+ default=None,
145
+ metadata={"help": "The number of processes to use for the preprocessing."},
146
+ )
147
+ max_train_samples: Optional[int] = field(
148
+ default=None,
149
+ metadata={
150
+ "help": (
151
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
152
+ "value if set."
153
+ )
154
+ },
155
+ )
156
+ max_eval_samples: Optional[int] = field(
157
+ default=None,
158
+ metadata={
159
+ "help": (
160
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
161
+ "value if set."
162
+ )
163
+ },
164
+ )
165
+ audio_column_name: str = field(
166
+ default="audio",
167
+ metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
168
+ )
169
+ text_column_name: str = field(
170
+ default="text",
171
+ metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
172
+ )
173
+ speaker_id_column_name: str = field(
174
+ default=None,
175
+ metadata={
176
+ "help": """If set, corresponds to the name of the speaker id column containing the speaker ids.
177
+ If `override_speaker_embeddings=False`:
178
+ it assumes that speakers are indexed from 0 to `num_speakers-1`.
179
+ `num_speakers` and `speaker_embedding_size` have to be set in the model config.
180
+
181
+ If `override_speaker_embeddings=True`:
182
+ It will use this column to compute how many speakers there are.
183
+
184
+ Defaults to None, i.e. it is not used by default."""
185
+ },
186
+ )
187
+ filter_on_speaker_id: int = field(
188
+ default=None,
189
+ metadata={
190
+ "help": (
191
+ "If `speaker_id_column_name` and `filter_on_speaker_id` are set, will filter the dataset to keep a single speaker_id (`filter_on_speaker_id`) "
192
+ )
193
+ },
194
+ )
195
+
196
+ max_tokens_length: float = field(
197
+ default=450,
198
+ metadata={
199
+ "help": ("Truncate audio files with a transcription that are longer than `max_tokens_length` tokens")
200
+ },
201
+ )
202
+ max_duration_in_seconds: float = field(
203
+ default=20.0,
204
+ metadata={
205
+ "help": (
206
+ "Truncate audio files that are longer than `max_duration_in_seconds` seconds to"
207
+ " 'max_duration_in_seconds`"
208
+ )
209
+ },
210
+ )
211
+ min_duration_in_seconds: float = field(
212
+ default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
213
+ )
214
+ preprocessing_only: bool = field(
215
+ default=False,
216
+ metadata={
217
+ "help": (
218
+ "Whether to only do data preprocessing and skip training. This is especially useful when data"
219
+ " preprocessing errors out in distributed training due to timeout. In this case, one should run the"
220
+ " preprocessing in a non-distributed setup with `preprocessing_only=True` so that the cached datasets"
221
+ " can consequently be loaded in distributed training"
222
+ )
223
+ },
224
+ )
225
+ train_split_name: str = field(
226
+ default="train",
227
+ metadata={
228
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
229
+ },
230
+ )
231
+ eval_split_name: str = field(
232
+ default="test",
233
+ metadata={
234
+ "help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'test'"
235
+ },
236
+ )
237
+ do_lower_case: bool = field(
238
+ default=False,
239
+ metadata={"help": "Whether the input text should be lower cased."},
240
+ )
241
+ do_normalize: bool = field(
242
+ default=False,
243
+ metadata={"help": "Whether the input waveform should be normalized."},
244
+ )
245
+ full_generation_sample_text: str = field(
246
+ default="This is a test, let's see what comes out of this.",
247
+ metadata={
248
+ "help": (
249
+ "Text prompt used to generate a full audio sample during evaluation, so that the fine-tuned "
250
+ "model's output can be listened to as training progresses."
251
+ )
252
+ },
253
+ )
254
+ uroman_path: str = field(
255
+ default=None,
256
+ metadata={
257
+ "help": (
258
+ "Absolute path to the uroman package. To use if your model requires `uroman`."
259
+ "An easy way to check it is to go to your model card and manually check `is_uroman` in the `tokenizer_config.json`, "
260
+ "e.g. the French checkpoint doesn't need it: https://huggingface.co/facebook/mms-tts-fra/blob/main/tokenizer_config.json#L4"
261
+ )
262
+ },
263
+ )
264
+
265
+ #.............................................................................................
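
A minimal sketch of how these three dataclasses are typically consumed with `transformers.HfArgumentParser`. The import path follows the repository layout, while the `finetune_config.json` file name and its contents are assumptions for illustration only:

# Sketch: parse the dataclasses defined in VitsModelSplit/Arguments.py.
# Assumptions: the repo root is on PYTHONPATH, and `finetune_config.json` is a
# hypothetical config file with keys such as "model_name_or_path",
# "dataset_name" and "output_dir".
from transformers import HfArgumentParser

from VitsModelSplit.Arguments import (
    DataTrainingArguments,
    ModelArguments,
    VITSTrainingArguments,
)

parser = HfArgumentParser((ModelArguments, DataTrainingArguments, VITSTrainingArguments))

# Either parse from a JSON file...
model_args, data_args, training_args = parser.parse_json_file("finetune_config.json")

# ...or from the command line, e.g.
#   python run.py --model_name_or_path facebook/mms-tts-ara --dataset_name my_ds --output_dir ./output
# model_args, data_args, training_args = parser.parse_args_into_dataclasses()

print(model_args.model_name_or_path, data_args.dataset_name, training_args.weight_mel)
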
VitsModelSplit/DATA ADDED
@@ -0,0 +1 @@
1
+
VitsModelSplit/FeaturesCollectionDataset_notebook.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
VitsModelSplit/PosteriorDecoderModel.py ADDED
@@ -0,0 +1,331 @@
1
+ import os
2
+ import sys
3
+ from typing import Optional
4
+ import numpy as np
5
+ import torch
6
+ from torch import nn
7
+ from transformers import set_seed
8
+ import wandb
9
+ import logging
10
+ import copy
11
+
12
+ from .vits_config import VitsConfig, VitsPreTrainedModel
13
+ from .feature_extraction import VitsFeatureExtractor
14
+ from .vits_output import PosteriorDecoderModelOutput
15
+ from .dataset_features_collector import FeaturesCollectionDataset
16
+ from .posterior_encoder import VitsPosteriorEncoder
17
+ from .decoder import VitsHifiGan
18
+
19
+ class PosteriorDecoderModel(torch.nn.Module):
20
+
21
+ def __init__(self, config,posterior_encoder,decoder,device=None):
22
+ super().__init__()
23
+
24
+ if device:
25
+ self.device = device
26
+ else:
27
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
+
29
+ self.config = copy.deepcopy(config)
30
+ self.posterior_encoder = copy.deepcopy(posterior_encoder)
31
+ self.decoder = copy.deepcopy(decoder)
32
+
33
+ if config.num_speakers > 1:
34
+ self.embed_speaker = nn.Embedding(config.num_speakers,
35
+ config.speaker_embedding_size
36
+ )
37
+ self.sampling_rate = config.sampling_rate
38
+ self.speaking_rate = config.speaking_rate
39
+ self.noise_scale = config.noise_scale
40
+ self.noise_scale_duration = config.noise_scale_duration
41
+ self.segment_size = self.config.segment_size // self.config.hop_length
42
+
43
+ self.to(self.device)
44
+
45
+
46
+
47
+ #....................................
48
+
49
+ def slice_segments(self,hidden_states, ids_str, segment_size=4):
50
+
51
+ batch_size, channels, _ = hidden_states.shape
52
+ # 1d tensor containing the indices to keep
53
+ indices = torch.arange(segment_size).to(ids_str.device)
54
+ # extend the indices to match the shape of hidden_states
55
+ indices = indices.view(1, 1, -1).expand(batch_size, channels, -1)
56
+ # offset indices with ids_str
57
+ indices = indices + ids_str.view(-1, 1, 1)
58
+ # gather indices
59
+ output = torch.gather(hidden_states, dim=2, index=indices)
60
+
61
+ return output
62
+
63
+ #....................................
64
+
65
+ def rand_slice_segments(self,hidden_states, sample_lengths=None, segment_size=4):
66
+ batch_size, _, seq_len = hidden_states.size()
67
+ if sample_lengths is None:
68
+ sample_lengths = seq_len
69
+ ids_str_max = sample_lengths - segment_size + 1
70
+ ids_str = (torch.rand([batch_size]).to(device=hidden_states.device) * ids_str_max).to(dtype=torch.long)
71
+ ret = self.slice_segments(hidden_states, ids_str, segment_size)
72
+
73
+ return ret, ids_str
74
+
75
+ #....................................
76
+
77
+ def forward(
78
+ self,
79
+ labels: Optional[torch.FloatTensor] = None,
80
+ labels_attention_mask: Optional[torch.Tensor] = None,
81
+ speaker_id: Optional[int] = None,
82
+ return_dict: Optional[bool] = True,
83
+ ) :
84
+
85
+ if self.config.num_speakers > 1 and speaker_id is not None:
86
+ if isinstance(speaker_id, int):
87
+ speaker_id = torch.full(size=(1,), fill_value=speaker_id, device=self.device)
88
+ elif isinstance(speaker_id, (list, tuple, np.ndarray)):
89
+ speaker_id = torch.tensor(speaker_id, device=self.device)
90
+
91
+ if not ((0 <= speaker_id).all() and (speaker_id < self.config.num_speakers).all()).item():
92
+ raise ValueError(f"Set `speaker_id` in the range 0-{self.config.num_speakers - 1}.")
93
+
94
+ if not (len(speaker_id) == 1 or len(speaker_id) == len(labels)):
95
+ raise ValueError(
96
+ f"You passed {len(speaker_id)} `speaker_id` but you should either pass one speaker id or `batch_size` `speaker_id`."
97
+ )
98
+
99
+ speaker_embeddings = self.embed_speaker(speaker_id).unsqueeze(-1)
100
+ else:
101
+ speaker_embeddings = None
102
+
103
+
104
+ if labels_attention_mask is not None:
105
+ labels_padding_mask = labels_attention_mask.unsqueeze(1).float()
106
+ else:
107
+ labels_attention_mask = torch.ones((labels.shape[0], labels.shape[2])).float().to(self.device)
108
+ labels_padding_mask = labels_attention_mask.unsqueeze(1)
109
+
110
+
111
+ posterior_latents, posterior_means, posterior_log_variances = self.posterior_encoder(
112
+ labels, labels_padding_mask, speaker_embeddings
113
+ )
114
+
115
+ label_lengths = labels_attention_mask.sum(dim=1)
116
+ latents_slice, ids_slice = self.rand_slice_segments(posterior_latents,
117
+ label_lengths,
118
+ segment_size=self.segment_size
119
+ )
120
+
121
+ waveform = self.decoder(latents_slice, speaker_embeddings)
122
+
123
+ if not return_dict:
124
+ outputs = (
125
+ labels_padding_mask,
126
+ posterior_latents,
127
+ posterior_means,
128
+ posterior_log_variances,
129
+ latents_slice,
130
+ ids_slice,
131
+ waveform,
132
+ )
133
+ return outputs
134
+
135
+ return PosteriorDecoderModelOutput(
136
+ labels_padding_mask = labels_padding_mask,
137
+ posterior_latents = posterior_latents,
138
+ posterior_means = posterior_means,
139
+ posterior_log_variances = posterior_log_variances,
140
+ latents_slice = latents_slice,
141
+ ids_slice = ids_slice,
142
+ waveform = waveform,
143
+ )
144
+
145
+
146
+
147
+ #....................................
148
+
149
+ def trainer(self,
150
+ train_dataset_dir = None,
151
+ eval_dataset_dir = None,
152
+ full_generation_dir = None,
153
+ feature_extractor = VitsFeatureExtractor(),
154
+ training_args = None,
155
+ full_generation_sample_index= 0,
156
+ project_name = "Posterior_Decoder_Finetuning",
157
+ wandbKey = "782b6a6e82bbb5a5348de0d3c7d40d1e76351e79",
158
+
159
+
160
+ ):
161
+
162
+ os.makedirs(training_args.output_dir,exist_ok=True)
163
+ logger = logging.getLogger(f"{__name__} Training")
164
+ log_level = training_args.get_process_log_level()
165
+ logger.setLevel(log_level)
166
+
167
+ wandb.login(key= wandbKey)
168
+ wandb.init(project= project_name,config = training_args.to_dict())
169
+
170
+
171
+ set_seed(training_args.seed)
172
+ # Apply Weight Norm Decoder
173
+ self.decoder.apply_weight_norm()
174
+ # Save Config
175
+ self.config.save_pretrained(training_args.output_dir)
176
+
177
+ train_dataset = FeaturesCollectionDataset(dataset_dir = train_dataset_dir,
178
+ device = self.device
179
+ )
180
+
181
+ eval_dataset = None
182
+ if training_args.do_eval:
183
+ eval_dataset = FeaturesCollectionDataset(dataset_dir = eval_dataset_dir,
184
+ device = self.device
185
+ )
186
+
187
+ full_generation_dataset = FeaturesCollectionDataset(dataset_dir = full_generation_dir,
188
+ device = self.device
189
+ )
190
+ self.full_generation_sample = full_generation_dataset[full_generation_sample_index]
191
+
192
+ # init optimizer, lr_scheduler
193
+
194
+ optimizer = torch.optim.AdamW(
195
+ self.parameters(),
196
+ training_args.learning_rate,
197
+ betas=[training_args.adam_beta1, training_args.adam_beta2],
198
+ eps=training_args.adam_epsilon,
199
+ )
200
+
201
+ lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
202
+ optimizer, gamma=training_args.lr_decay, last_epoch=-1
203
+ )
204
+
205
+
206
+ logger.info("***** Running training *****")
207
+ logger.info(f" Num Epochs = {training_args.num_train_epochs}")
208
+
209
+
210
+ #.......................loop training............................
211
+
212
+ global_step = 0
213
+
214
+ for epoch in range(training_args.num_train_epochs):
215
+ train_losses_sum = 0
216
+ lr_scheduler.step()
217
+
218
+ for step, batch in enumerate(train_dataset):
219
+
220
+ # forward through model
221
+ outputs = self.forward(
222
+ labels=batch["labels"],
223
+ labels_attention_mask=batch["labels_attention_mask"],
224
+ speaker_id=batch["speaker_id"]
225
+ )
226
+
227
+ mel_scaled_labels = batch["mel_scaled_input_features"]
228
+ mel_scaled_target = self.slice_segments(mel_scaled_labels, outputs.ids_slice,self.segment_size)
229
+ mel_scaled_generation = feature_extractor._torch_extract_fbank_features(outputs.waveform.squeeze(1))[1]
230
+
231
+ target_waveform = batch["waveform"].transpose(1, 2)
232
+ target_waveform = self.slice_segments(
233
+ target_waveform,
234
+ outputs.ids_slice * feature_extractor.hop_length,
235
+ self.config.segment_size
236
+ )
237
+
238
+
239
+ # backpropagate
240
+
241
+ loss_mel = torch.nn.functional.l1_loss(mel_scaled_target, mel_scaled_generation)
242
+ loss = loss_mel.detach().item()
243
+ train_losses_sum = train_losses_sum + loss
244
+ loss_mel.backward()
245
+ optimizer.step()
246
+ optimizer.zero_grad()
247
+
248
+ print(f"TRAINING - batch {step}, waveform {(batch['waveform'].shape)}, step_loss_mel {loss}, lr {lr_scheduler.get_last_lr()[0]}... ")
249
+ global_step +=1
250
+
251
+ # validation
252
+
253
+ do_eval = training_args.do_eval and (global_step % training_args.eval_steps == 0)
254
+ if do_eval:
255
+ logger.info("Running validation... ")
256
+ eval_losses_sum = 0
257
+ for step, batch in enumerate(eval_dataset):
258
+
259
+ with torch.no_grad():
260
+ outputs = self.forward(
261
+ labels=batch["labels"],
262
+ labels_attention_mask=batch["labels_attention_mask"],
263
+ speaker_id=batch["speaker_id"]
264
+ )
265
+
266
+ mel_scaled_labels = batch["mel_scaled_input_features"]
267
+ mel_scaled_target = self.slice_segments(mel_scaled_labels, outputs.ids_slice,self.segment_size)
268
+ mel_scaled_generation = feature_extractor._torch_extract_fbank_features(outputs.waveform.squeeze(1))[1]
269
+ loss_mel = torch.nn.functional.l1_loss(mel_scaled_target, mel_scaled_generation)
270
+ loss = loss_mel.detach().item()
271
+ eval_losses_sum += loss
272
+ print(f"VALIDATION - batch {step}, waveform {(batch['waveform'].shape)}, step_loss_mel {loss} ... ")
273
+
274
+
275
+
276
+ with torch.no_grad():
277
+ full_generation_sample = self.full_generation_sample
278
+ full_generation =self.forward(
279
+ labels=full_generation_sample["labels"],
280
+ labels_attention_mask=full_generation_sample["labels_attention_mask"],
281
+ speaker_id=full_generation_sample["speaker_id"]
282
+ )
283
+
284
+ full_generation_waveform = full_generation.waveform.cpu().numpy()
285
+
286
+ wandb.log({
287
+ "eval_losses": eval_losses_sum,
288
+ "full generations samples": [
289
+ wandb.Audio(w.reshape(-1), caption=f"Full generation sample {epoch}", sample_rate=self.sampling_rate)
290
+ for w in full_generation_waveform],})
291
+
292
+ wandb.log({"train_losses":train_losses_sum})
293
+
294
+ # remove weight norms
295
+ self.decoder.remove_weight_norm()
296
+
297
+
298
+ torch.save(self.posterior_encoder.state_dict(), os.path.join(training_args.output_dir,"posterior_encoder.pt"))
299
+ torch.save(self.decoder.state_dict(), os.path.join(training_args.output_dir,"decoder.pt"))
300
+
301
+
302
+
303
+ logger.info("Running final full generations samples... ")
304
+
305
+
306
+ with torch.no_grad():
307
+
308
+ full_generation_sample = self.full_generation_sample
309
+ full_generation = self.forward(
310
+ labels=full_generation_sample["labels"],
311
+ labels_attention_mask=full_generation_sample["labels_attention_mask"],
312
+ speaker_id=full_generation_sample["speaker_id"]
313
+ )
314
+
315
+ full_generation_waveform = full_generation.waveform.cpu().numpy()
316
+
317
+ wandb.log({"eval_losses": eval_losses_sum,
318
+ "full generations samples": [
319
+ wandb.Audio(w.reshape(-1), caption=f"Full generation sample {epoch}",
320
+ sample_rate=self.sampling_rate) for w in full_generation_waveform],
321
+ })
322
+
323
+
324
+ logger.info("***** Training / Inference Done *****")
325
+
326
+ #....................................
327
+
328
+
329
+
330
+
331
+ #....................................
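
For reference, a self-contained sketch of the random segment slicing performed by `rand_slice_segments`/`slice_segments` above, run on dummy tensors; the shapes and lengths here are made up purely for illustration:

# Dummy-data sketch of the segment slicing done before decoding:
# pick a random window of `segment_size` frames per batch element and gather it.
import torch

batch_size, channels, seq_len, segment_size = 2, 192, 100, 32
hidden_states = torch.randn(batch_size, channels, seq_len)
lengths = torch.tensor([100, 80])  # per-sample valid lengths

# random start index in [0, length - segment_size] for each sample
ids_str = (torch.rand(batch_size) * (lengths - segment_size + 1)).long()

# build gather indices of shape (batch, channels, segment_size) and slice
indices = torch.arange(segment_size).view(1, 1, -1).expand(batch_size, channels, -1)
indices = indices + ids_str.view(-1, 1, 1)
segments = torch.gather(hidden_states, dim=2, index=indices)

print(segments.shape)  # torch.Size([2, 192, 32])
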
VitsModelSplit/PosteriorDecoderModel_notebook.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
VitsModelSplit/Trainer.py ADDED
@@ -0,0 +1,848 @@
1
+ import os
2
+ import shutil
3
+ import tempfile
4
+ import numpy as np
5
+ import wandb
6
+ from transformers import VitsModel
7
+ import math
8
+ import torch
9
+ from accelerate.utils import ProjectConfiguration, is_wandb_available, set_seed
10
+ from accelerate import Accelerator, DistributedDataParallelKwargs
11
+ from transformers.utils import send_example_telemetry
12
+ import logging
13
+ import sys
14
+ from transformers.trainer_utils import get_last_checkpoint, is_main_process
15
+ from transformers.trainer_pt_utils import LengthGroupedSampler
16
+ from transformers.optimization import get_scheduler
17
+
18
+
19
+ from .data_collator import DataCollatorTTSWithPadding
20
+ from .discriminator import VitsDiscriminator
21
+ from .feature_extraction import VitsFeatureExtractor
22
+ from .plot import plot_alignment_to_numpy, plot_spectrogram_to_numpy
23
+
24
+ #.............................................
25
+
26
+ if is_wandb_available():
27
+ import wandb
28
+
29
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
30
+ logger = logging.getLogger(__name__)
31
+ #.............................................
32
+
33
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
34
+ loss = 0
35
+ real_losses = 0
36
+ generated_losses = 0
37
+ for disc_real, disc_generated in zip(disc_real_outputs, disc_generated_outputs):
38
+ real_loss = torch.mean((1 - disc_real) ** 2)
39
+ generated_loss = torch.mean(disc_generated**2)
40
+ loss += real_loss + generated_loss
41
+ real_losses += real_loss
42
+ generated_losses += generated_loss
43
+
44
+ return loss, real_losses, generated_losses
45
+
46
+
47
+ def feature_loss(feature_maps_real, feature_maps_generated):
48
+ loss = 0
49
+ for feature_map_real, feature_map_generated in zip(feature_maps_real, feature_maps_generated):
50
+ for real, generated in zip(feature_map_real, feature_map_generated):
51
+ real = real.detach()
52
+ loss += torch.mean(torch.abs(real - generated))
53
+
54
+ return loss * 2
55
+
56
+
57
+ def generator_loss(disc_outputs):
58
+ total_loss = 0
59
+ gen_losses = []
60
+ for disc_output in disc_outputs:
61
+ disc_output = disc_output
62
+ loss = torch.mean((1 - disc_output) ** 2)
63
+ gen_losses.append(loss)
64
+ total_loss += loss
65
+
66
+ return total_loss, gen_losses
67
+
68
+
69
+ def kl_loss(prior_latents, posterior_log_variance, prior_means, prior_log_variance, labels_mask):
70
+ """
71
+ z_p, logs_q: [b, h, t_t]
72
+ prior_means, prior_log_variance: [b, h, t_t]
73
+ """
74
+
75
+ kl = prior_log_variance - posterior_log_variance - 0.5
76
+ kl += 0.5 * ((prior_latents - prior_means) ** 2) * torch.exp(-2.0 * prior_log_variance)
77
+ kl = torch.sum(kl * labels_mask)
78
+ loss = kl / torch.sum(labels_mask)
79
+ return loss
80
+
81
+
82
+ def log_on_trackers(
83
+ trackers,
84
+ generated_audio,
85
+ generated_attn,
86
+ generated_spec,
87
+ target_spec,
88
+ full_generation_waveform,
89
+ epoch,
90
+ sampling_rate,
91
+ ):
92
+ max_num_samples = min(len(generated_audio), 50)
93
+ generated_audio = generated_audio[:max_num_samples]
94
+ generated_attn = generated_attn[:max_num_samples]
95
+ generated_spec = generated_spec[:max_num_samples]
96
+ target_spec = target_spec[:max_num_samples]
97
+
98
+ for tracker in trackers:
99
+ if tracker.name == "tensorboard":
100
+ for cpt, audio in enumerate(generated_audio):
101
+ tracker.writer.add_audio(f"train_step_audio_{cpt}", audio[None, :], epoch, sample_rate=sampling_rate)
102
+
103
+ for cpt, audio in enumerate(full_generation_waveform):
104
+ tracker.writer.add_audio(
105
+ f"full_generation_sample{cpt}", audio[None, :], epoch, sample_rate=sampling_rate
106
+ )
107
+
108
+ tracker.writer.add_images("alignements", np.stack(generated_attn), dataformats="NHWC")
109
+ tracker.writer.add_images("spectrogram", np.stack(generated_spec), dataformats="NHWC")
110
+ tracker.writer.add_images("target spectrogram", np.stack(target_spec), dataformats="NHWC")
111
+ elif tracker.name == "wandb":
112
+ # wandb can only log 100 audios per step
113
+ tracker.log(
114
+ {
115
+ "alignments": [wandb.Image(attn, caption=f"Audio epoch {epoch}") for attn in generated_attn],
116
+ "spectrogram": [wandb.Image(spec, caption=f"Audio epoch {epoch}") for spec in generated_spec],
117
+ "target spectrogram": [wandb.Image(spec, caption=f"Audio epoch {epoch}") for spec in target_spec],
118
+ "train generated audio": [
119
+ wandb.Audio(
120
+ audio[0],
121
+ caption=f"Audio during train step epoch {epoch}",
122
+ sample_rate=sampling_rate,
123
+ )
124
+ for audio in generated_audio
125
+ ],
126
+ "full generations samples": [
127
+ wandb.Audio(w, caption=f"Full generation sample {epoch}", sample_rate=sampling_rate)
128
+ for w in full_generation_waveform
129
+ ],
130
+ }
131
+ )
132
+ else:
133
+ logger.warning(f"audio logging not implemented for {tracker.name}")
134
+
135
+
136
+ def compute_val_metrics_and_losses(
137
+ val_losses,
138
+ accelerator,
139
+ model_outputs,
140
+ mel_scaled_generation,
141
+ mel_scaled_target,
142
+ batch_size,
143
+ compute_clap_similarity=False,
144
+ ):
145
+ loss_mel = torch.nn.functional.l1_loss(mel_scaled_target, mel_scaled_generation)
146
+ loss_kl = kl_loss(
147
+ model_outputs.prior_latents,
148
+ model_outputs.posterior_log_variances,
149
+ model_outputs.prior_means,
150
+ model_outputs.prior_log_variances,
151
+ model_outputs.labels_padding_mask,
152
+ )
153
+
154
+ losses_mel_kl = loss_mel + loss_kl
155
+
156
+ losses = torch.stack([loss_mel, loss_kl, losses_mel_kl])
157
+ losses = accelerator.gather(losses.repeat(batch_size, 1)).mean(0)
158
+
159
+ for key, loss in zip(["val_loss_mel", "val_loss_kl", "val_loss_mel_kl"], losses):
160
+ val_losses[key] = val_losses.get(key, 0) + loss.item()
161
+
162
+ return val_losses
163
+
164
+
165
+ #.............................................
166
+
167
+ def vits_trainin(
168
+ model,
169
+ tokenizer,
170
+ model_args,
171
+ data_args,
172
+ training_args,
173
+ train_dataset,
174
+ eval_dataset,
175
+
176
+ ):
177
+
178
+
179
+
180
+
181
+ send_example_telemetry("run_vits_finetuning", model_args, data_args)
182
+
183
+ logging.basicConfig(
184
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
185
+ datefmt="%m/%d/%Y %H:%M:%S",
186
+ handlers=[logging.StreamHandler(sys.stdout)],
187
+ )
188
+ log_level = training_args.get_process_log_level()
189
+ logger.setLevel(log_level)
190
+ # datasets.utils.logging.set_verbosity(log_level)
191
+ # transformers.utils.logging.set_verbosity(log_level)
192
+ # transformers.utils.logging.enable_default_handler()
193
+ # transformers.utils.logging.enable_explicit_format()
194
+ # # logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
195
+ # if is_main_process(training_args.local_rank):
196
+ # transformers.utils.logging.set_verbosity_info()
197
+
198
+
199
+
200
+
201
+ set_seed(training_args.seed)
202
+
203
+
204
+
205
+ config = model.config
206
+ feature_extractor = VitsFeatureExtractor()
207
+
208
+ forward_attention_mask = True
209
+
210
+
211
+ with training_args.main_process_first(desc="apply_weight_norm"):
212
+ # apply weight norms
213
+ model.decoder.apply_weight_norm()
214
+ for flow in model.flow.flows:
215
+ torch.nn.utils.weight_norm(flow.conv_pre)
216
+ torch.nn.utils.weight_norm(flow.conv_post)
217
+
218
+
219
+
220
+ with training_args.main_process_first():
221
+ # only the main process saves them
222
+ if is_main_process(training_args.local_rank):
223
+ # save feature extractor, tokenizer and config
224
+ feature_extractor.save_pretrained(training_args.output_dir)
225
+ tokenizer.save_pretrained(training_args.output_dir)
226
+ config.save_pretrained(training_args.output_dir)
227
+
228
+
229
+ data_collator = DataCollatorTTSWithPadding(
230
+ tokenizer=tokenizer,
231
+ feature_extractor=feature_extractor,
232
+ forward_attention_mask=forward_attention_mask,
233
+ )
234
+
235
+ with training_args.main_process_first():
236
+ input_str = data_args.full_generation_sample_text
237
+ full_generation_sample = tokenizer(input_str, return_tensors="pt")
238
+
239
+
240
+ project_name = data_args.project_name
241
+ logging_dir = os.path.join(training_args.output_dir, training_args.logging_dir)
242
+ accelerator_project_config = ProjectConfiguration(project_dir=training_args.output_dir, logging_dir=logging_dir)
243
+
244
+ accelerator = Accelerator(
245
+ gradient_accumulation_steps=training_args.gradient_accumulation_steps,
246
+ log_with=training_args.report_to,
247
+ project_config=accelerator_project_config,
248
+ kwargs_handlers=[ddp_kwargs],
249
+ )
250
+
251
+ per_device_train_batch_size = (
252
+ training_args.per_device_train_batch_size if training_args.per_device_train_batch_size else 1
253
+ )
254
+ total_batch_size = (
255
+ per_device_train_batch_size * accelerator.num_processes * training_args.gradient_accumulation_steps
256
+ )
257
+
258
+ num_speakers = model.config.num_speakers
259
+ if training_args.gradient_checkpointing:
260
+ model.gradient_checkpointing_enable()
261
+
262
+
263
+
264
+ train_dataloader = None
265
+ if training_args.do_train:
266
+ sampler = (
267
+ LengthGroupedSampler(
268
+ batch_size=per_device_train_batch_size,
269
+ dataset=train_dataset,
270
+ lengths=train_dataset["tokens_input_length"],
271
+ )
272
+ if training_args.group_by_length
273
+ else None
274
+ )
275
+ train_dataloader = torch.utils.data.DataLoader(
276
+ train_dataset,
277
+ shuffle=False,#not training_args.group_by_length,
278
+ collate_fn=data_collator,
279
+ batch_size=training_args.per_device_train_batch_size,
280
+ num_workers=training_args.dataloader_num_workers,
281
+ sampler=sampler,
282
+ )
283
+
284
+ eval_dataloader = None
285
+ if training_args.do_eval:
286
+ eval_sampler = (
287
+ LengthGroupedSampler(
288
+ batch_size=training_args.per_device_eval_batch_size,
289
+ dataset=eval_dataset,
290
+ lengths=eval_dataset["tokens_input_length"],
291
+ )
292
+ if training_args.group_by_length
293
+ else None
294
+ )
295
+
296
+ eval_dataloader = torch.utils.data.DataLoader(
297
+ eval_dataset,
298
+ shuffle=False,
299
+ collate_fn=data_collator,
300
+ batch_size=training_args.per_device_eval_batch_size,
301
+ num_workers=training_args.dataloader_num_workers,
302
+ sampler=eval_sampler,
303
+ )
304
+
305
+ model_segment_size = model.segment_size
306
+ config_segment_size = model.config.segment_size
307
+ sampling_rate = model.config.sampling_rate
308
+
309
+ # Scheduler and math around the number of training steps.
310
+ overrode_max_train_steps = False
311
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps)
312
+ if training_args.max_steps == -1:
313
+ training_args.max_steps = training_args.num_train_epochs * num_update_steps_per_epoch
314
+ overrode_max_train_steps = True
315
+
316
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
317
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps)
318
+ if overrode_max_train_steps:
319
+ training_args.max_steps = int(training_args.num_train_epochs * num_update_steps_per_epoch)
320
+ # Afterwards we recalculate our number of training epochs
321
+ training_args.num_train_epochs = math.ceil(training_args.max_steps / num_update_steps_per_epoch)
322
+
323
+ # hack to be able to train on multiple devices
324
+ with tempfile.TemporaryDirectory() as tmpdirname:
325
+ model.discriminator.save_pretrained(tmpdirname)
326
+ discriminator = VitsDiscriminator.from_pretrained(tmpdirname)
327
+ for disc in discriminator.discriminators:
328
+ disc.apply_weight_norm()
329
+ del model.discriminator
330
+
331
+ # init gen_optimizer, gen_lr_scheduler, disc_optimizer, disc_lr_scheduler
332
+ gen_optimizer = torch.optim.AdamW(
333
+ model.parameters(),
334
+ training_args.learning_rate,
335
+ betas=[training_args.adam_beta1, training_args.adam_beta2],
336
+ eps=training_args.adam_epsilon,
337
+ )
338
+
339
+ disc_optimizer = torch.optim.AdamW(
340
+ discriminator.parameters(),
341
+ training_args.learning_rate,
342
+ betas=[training_args.adam_beta1, training_args.adam_beta2],
343
+ eps=training_args.adam_epsilon,
344
+ )
345
+
346
+ num_warmups_steps = training_args.get_warmup_steps(training_args.num_train_epochs * accelerator.num_processes)
347
+ num_training_steps = training_args.num_train_epochs * accelerator.num_processes
348
+
349
+ gen_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
350
+ gen_optimizer, gamma=training_args.lr_decay, last_epoch=-1
351
+ )
352
+ disc_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
353
+ disc_optimizer, gamma=training_args.lr_decay, last_epoch=-1
354
+ )
355
+
356
+
357
+ # Prepare everything with our `accelerator`.
358
+ (
359
+ model,
360
+ discriminator,
361
+ gen_optimizer,
362
+ gen_lr_scheduler,
363
+ disc_optimizer,
364
+ disc_lr_scheduler,
365
+ train_dataloader,
366
+ eval_dataloader,
367
+ ) = accelerator.prepare(
368
+ model,
369
+ discriminator,
370
+ gen_optimizer,
371
+ gen_lr_scheduler,
372
+ disc_optimizer,
373
+ disc_lr_scheduler,
374
+ train_dataloader,
375
+ eval_dataloader,
376
+ )
377
+
378
+
379
+ # We need to initialize the trackers we use, and also store our configuration.
380
+ # The trackers initialize automatically on the main process.
381
+ if accelerator.is_main_process:
382
+ tracker_config = training_args.to_sanitized_dict()
383
+ accelerator.init_trackers(project_name, tracker_config)
384
+
385
+
386
+
387
+ # Train!
388
+ logger.info("***** Running training *****")
389
+ logger.info(f" Num examples = {len(train_dataset)}")
390
+ logger.info(f" Num Epochs = {training_args.num_train_epochs}")
391
+ logger.info(f" Instantaneous batch size per device = {per_device_train_batch_size}")
392
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
393
+ logger.info(f" Gradient Accumulation steps = {training_args.gradient_accumulation_steps}")
394
+ logger.info(f" Total optimization steps = {training_args.max_steps}")
395
+ global_step = 0
396
+ first_epoch = 0
397
+
398
+
399
+
400
+ # Potentially load in the weights and states from a previous save
401
+ if training_args.resume_from_checkpoint:
402
+ if training_args.resume_from_checkpoint != "latest":
403
+ path = os.path.basename(training_args.resume_from_checkpoint)
404
+ else:
405
+ # Get the most recent checkpoint
406
+ dirs = os.listdir(training_args.output_dir)
407
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
408
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
409
+ path = dirs[-1] if len(dirs) > 0 else None
410
+
411
+ if path is None:
412
+ accelerator.print(
413
+ f"Checkpoint '{training_args.resume_from_checkpoint}' does not exist. Starting a new training run."
414
+ )
415
+ training_args.resume_from_checkpoint = None
416
+ initial_global_step = 0
417
+ else:
418
+ accelerator.print(f"Resuming from checkpoint {path}")
419
+ accelerator.load_state(os.path.join(training_args.output_dir, path))
420
+ global_step = int(path.split("-")[1])
421
+
422
+ initial_global_step = global_step
423
+ first_epoch = global_step // num_update_steps_per_epoch
424
+
425
+ else:
426
+ initial_global_step = 0
427
+
428
+
429
+
430
+ #.......................loop training............................
431
+
432
+ for epoch in range(first_epoch, training_args.num_train_epochs):
433
+ # keep track of train losses
434
+ train_losses = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
435
+
436
+ disc_lr_scheduler.step()
437
+ gen_lr_scheduler.step()
438
+
439
+ for step, batch in enumerate(train_dataloader):
440
+ print(f"TRAINING - batch {step}, process{accelerator.process_index}, waveform {(batch['waveform'].shape)}, tokens {(batch['input_ids'].shape)}... ")
441
+ with accelerator.accumulate(model, discriminator):
442
+ # forward through model
443
+ model_outputs = model(
444
+ input_ids=batch["input_ids"],
445
+ attention_mask=batch["attention_mask"],
446
+ labels=batch["labels"],
447
+ labels_attention_mask=batch["labels_attention_mask"],
448
+ speaker_id=batch["speaker_id"],
449
+ encoder_output = batch['text_encoder_output'],
450
+
451
+ return_dict=True,
452
+ monotonic_alignment_function=None,
453
+ )
454
+
455
+ mel_scaled_labels = batch["mel_scaled_input_features"]
456
+ mel_scaled_target = model.slice_segments(mel_scaled_labels, model_outputs.ids_slice, model_segment_size)
457
+ mel_scaled_generation = feature_extractor._torch_extract_fbank_features(
458
+ model_outputs.waveform.squeeze(1)
459
+ )[1]
460
+
461
+ target_waveform = batch["waveform"].transpose(1, 2)
462
+ target_waveform = model.slice_segments(
463
+ target_waveform, model_outputs.ids_slice * feature_extractor.hop_length, config_segment_size
464
+ )
465
+
466
+ # -----------------------
467
+ # Train Discriminator
468
+ # -----------------------
469
+
470
+ discriminator_target, _ = discriminator(target_waveform)
471
+ discriminator_candidate, _ = discriminator(model_outputs.waveform.detach())
472
+
473
+ loss_disc, loss_real_disc, loss_fake_disc = discriminator_loss(
474
+ discriminator_target, discriminator_candidate
475
+ )
476
+
477
+ # backpropagate
478
+ accelerator.backward(loss_disc * training_args.weight_disc)
479
+ if accelerator.sync_gradients:
480
+ accelerator.clip_grad_norm_(discriminator.parameters(), training_args.max_grad_norm)
481
+ disc_optimizer.step()
482
+ if not training_args.do_step_schedule_per_epoch:
483
+ disc_lr_scheduler.step()
484
+ disc_optimizer.zero_grad()
485
+
486
+ # -----------------------
487
+ # Train Generator
488
+ # -----------------------
489
+
490
+ _, fmaps_target = discriminator(target_waveform)
491
+ discriminator_candidate, fmaps_candidate = discriminator(model_outputs.waveform)
492
+
493
+ loss_duration = torch.sum(model_outputs.log_duration)
494
+ loss_mel = torch.nn.functional.l1_loss(mel_scaled_target, mel_scaled_generation)
495
+ loss_kl = kl_loss(
496
+ model_outputs.prior_latents,
497
+ model_outputs.posterior_log_variances,
498
+ model_outputs.prior_means,
499
+ model_outputs.prior_log_variances,
500
+ model_outputs.labels_padding_mask,
501
+ )
502
+ loss_fmaps = feature_loss(fmaps_target, fmaps_candidate)
503
+ loss_gen, losses_gen = generator_loss(discriminator_candidate)
504
+
505
+ total_generator_loss = (
506
+ loss_duration * training_args.weight_duration
507
+ + loss_mel * training_args.weight_mel
508
+ + loss_kl * training_args.weight_kl
509
+ + loss_fmaps * training_args.weight_fmaps
510
+ + loss_gen * training_args.weight_gen
511
+ )
512
+
513
+ # backpropagate
514
+ accelerator.backward(total_generator_loss)
515
+ if accelerator.sync_gradients:
516
+ accelerator.clip_grad_norm_(model.parameters(), training_args.max_grad_norm)
517
+ gen_optimizer.step()
518
+ if not training_args.do_step_schedule_per_epoch:
519
+ gen_lr_scheduler.step()
520
+ gen_optimizer.zero_grad()
521
+
522
+ # update and gather losses
523
+ losses = torch.stack(
524
+ [
525
+ # for fair comparison, don't use weighted loss
526
+ loss_duration + loss_mel + loss_kl + loss_fmaps + loss_gen,
527
+ loss_duration,
528
+ loss_mel,
529
+ loss_kl,
530
+ loss_fmaps,
531
+ loss_gen,
532
+ loss_disc,
533
+ loss_real_disc,
534
+ loss_fake_disc,
535
+ ]
536
+ )
537
+ losses = accelerator.gather(losses.repeat(per_device_train_batch_size, 1)).mean(0)
538
+
539
+ train_losses = [
540
+ l + losses[i].item() / training_args.gradient_accumulation_steps
541
+ for (i, l) in enumerate(train_losses)
542
+ ]
543
+
544
+ # Checks if the accelerator has performed an optimization step behind the scenes
545
+ if accelerator.sync_gradients:
546
+ (
547
+ train_summed_losses,
548
+ train_loss_duration,
549
+ train_loss_mel,
550
+ train_loss_kl,
551
+ train_loss_fmaps,
552
+ train_loss_gen,
553
+ train_loss_disc,
554
+ train_loss_real_disc,
555
+ train_loss_fake_disc,
556
+ ) = train_losses
557
+
558
+ global_step += 1
559
+ accelerator.log(
560
+ {
561
+ "train_summed_losses": train_summed_losses,
562
+ "train_loss_disc": train_loss_disc,
563
+ "train_loss_real_disc": train_loss_real_disc,
564
+ "train_loss_fake_disc": train_loss_fake_disc,
565
+ "train_loss_duration": train_loss_duration,
566
+ "train_loss_mel": train_loss_mel,
567
+ "train_loss_kl": train_loss_kl,
568
+ "train_loss_fmaps": train_loss_fmaps,
569
+ "train_loss_gen": train_loss_gen,
570
+ "lr": disc_lr_scheduler.get_last_lr()[0],
571
+ },
572
+ step=global_step,
573
+ )
574
+ train_losses = [0.0 for _ in train_losses]
575
+
576
+ if global_step % training_args.save_steps == 0:
577
+ if accelerator.is_main_process:
578
+ # _before_ saving state, check if this save would set us over the `save_total_limit`
579
+ if training_args.save_total_limit is not None:
580
+ checkpoints = os.listdir(training_args.output_dir)
581
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
582
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
583
+
584
+ # before we save the new checkpoint, we need to have at _most_ `save_total_limit - 1` checkpoints
585
+ if len(checkpoints) >= training_args.save_total_limit:
586
+ num_to_remove = len(checkpoints) - training_args.save_total_limit + 1
587
+ removing_checkpoints = checkpoints[0:num_to_remove]
588
+
589
+ logger.info(
590
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
591
+ )
592
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
593
+
594
+ for removing_checkpoint in removing_checkpoints:
595
+ removing_checkpoint = os.path.join(training_args.output_dir, removing_checkpoint)
596
+ shutil.rmtree(removing_checkpoint)
597
+
598
+ save_path = os.path.join(training_args.output_dir, f"checkpoint-{global_step}")
599
+ accelerator.save_state(save_path)
600
+ logger.info(f"Saved state to {save_path}")
601
+
602
+ logs = {
603
+ "step_loss": total_generator_loss.detach().item(),
604
+ "lr": disc_lr_scheduler.get_last_lr()[0],
605
+ "step_loss_duration": loss_duration.detach().item(),
606
+ "step_loss_mel": loss_mel.detach().item(),
607
+ "step_loss_kl": loss_kl.detach().item(),
608
+ "step_loss_fmaps": loss_fmaps.detach().item(),
609
+ "step_loss_gen": loss_gen.detach().item(),
610
+ "step_loss_disc": loss_disc.detach().item(),
611
+ "step_loss_real_disc": loss_real_disc.detach().item(),
612
+ "step_loss_fake_disc": loss_fake_disc.detach().item(),
613
+ }
614
+
615
+
616
+ if global_step >= training_args.max_steps:
617
+ break
618
+
619
+ eval_steps = training_args.eval_steps if training_args.eval_steps else 1
620
+ do_eval = training_args.do_eval and (global_step % eval_steps == 0) and accelerator.sync_gradients
621
+
622
+ if do_eval:
623
+ logger.info("Running validation... ")
624
+ generated_audio = []
625
+ generated_attn = []
626
+ generated_spec = []
627
+ target_spec = []
628
+ val_losses = {}
629
+ for step, batch in enumerate(eval_dataloader):
630
+ print(
631
+ f"VALIDATION - batch {step}, process{accelerator.process_index}, waveform {(batch['waveform'].shape)}, tokens {(batch['input_ids'].shape)}... "
632
+ )
633
+ with torch.no_grad():
634
+ model_outputs_train = model(
635
+ input_ids=batch["input_ids"],
636
+ attention_mask=batch["attention_mask"],
637
+ labels=batch["labels"],
638
+ labels_attention_mask=batch["labels_attention_mask"],
639
+ speaker_id=batch["speaker_id"],
640
+ encoder_output = batch['text_encoder_output'],
641
+
642
+ return_dict=True,
643
+ monotonic_alignment_function=None,
644
+ )
645
+
646
+ mel_scaled_labels = batch["mel_scaled_input_features"]
647
+ mel_scaled_target = model.slice_segments(
648
+ mel_scaled_labels, model_outputs_train.ids_slice, model_segment_size
649
+ )
650
+ mel_scaled_generation = feature_extractor._torch_extract_fbank_features(
651
+ model_outputs_train.waveform.squeeze(1)
652
+ )[1]
653
+
654
+ val_losses = compute_val_metrics_and_losses(
655
+ val_losses,
656
+ accelerator,
657
+ model_outputs_train,
658
+ mel_scaled_generation,
659
+ mel_scaled_target,
660
+ per_device_train_batch_size,
661
+ compute_clap_similarity=False,
662
+ )
663
+
664
+ print(f"VALIDATION - batch {step}, process{accelerator.process_index}, PADDING AND GATHER... ")
665
+ specs = feature_extractor._torch_extract_fbank_features(model_outputs_train.waveform.squeeze(1))[0]
666
+ padded_attn, specs, target_specs = accelerator.pad_across_processes(
667
+ [model_outputs_train.attn.squeeze(1), specs, batch["labels"]], dim=1
668
+ )
669
+ padded_attn, specs, target_specs = accelerator.pad_across_processes(
670
+ [padded_attn, specs, target_specs], dim=2
671
+ )
672
+
673
+ generated_train_waveform, padded_attn, specs, target_specs = accelerator.gather_for_metrics(
674
+ [model_outputs_train.waveform, padded_attn, specs, target_specs]
675
+ )
676
+
677
+
678
+ if accelerator.is_main_process:
679
+ with torch.no_grad():
680
+ speaker_id = None if num_speakers < 2 else list(range(min(5, num_speakers)))
681
+ full_generation = model(**full_generation_sample.to(model.device), speaker_id=speaker_id)
682
+
683
+ generated_audio.append(generated_train_waveform.cpu())
684
+ generated_attn.append(padded_attn.cpu())
685
+ generated_spec.append(specs.cpu())
686
+ target_spec.append(target_specs.cpu())
687
+
688
+ logger.info("Validation inference done, now evaluating... ")
689
+ if accelerator.is_main_process:
690
+ generated_audio = [audio.numpy() for audio_batch in generated_audio for audio in audio_batch]
691
+ generated_attn = [
692
+ plot_alignment_to_numpy(attn.numpy()) for attn_batch in generated_attn for attn in attn_batch
693
+ ]
694
+ generated_spec = [
695
+ plot_spectrogram_to_numpy(attn.numpy()) for attn_batch in generated_spec for attn in attn_batch
696
+ ]
697
+ target_spec = [
698
+ plot_spectrogram_to_numpy(attn.numpy()) for attn_batch in target_spec for attn in attn_batch
699
+ ]
700
+ full_generation_waveform = full_generation.waveform.cpu().numpy()
701
+
702
+ accelerator.log(val_losses, step=global_step)
703
+
704
+ log_on_trackers(
705
+ accelerator.trackers,
706
+ generated_audio,
707
+ generated_attn,
708
+ generated_spec,
709
+ target_spec,
710
+ full_generation_waveform,
711
+ epoch,
712
+ sampling_rate,
713
+ )
714
+
715
+ logger.info("Validation finished... ")
716
+
717
+ accelerator.wait_for_everyone()
718
+
719
+ accelerator.wait_for_everyone()
720
+ if accelerator.is_main_process:
721
+ epoch = training_args.num_train_epochs if training_args.num_train_epochs else 1
722
+ eval_steps = training_args.eval_steps if training_args.eval_steps else 1
723
+
724
+ # Run a final round of inference.
725
+ do_eval = training_args.do_eval
726
+
727
+ if do_eval:
728
+ logger.info("Running final validation... ")
729
+ generated_audio = []
730
+ generated_attn = []
731
+ generated_spec = []
732
+ target_spec = []
733
+ val_losses = {}
734
+ for step, batch in enumerate(eval_dataloader):
735
+ print(
736
+ f"VALIDATION - batch {step}, process{accelerator.process_index}, waveform {(batch['waveform'].shape)}, tokens {(batch['input_ids'].shape)}... "
737
+ )
738
+ with torch.no_grad():
739
+ model_outputs_train = model(
740
+ input_ids=batch["input_ids"],
741
+ attention_mask=batch["attention_mask"],
742
+ labels=batch["labels"],
743
+ labels_attention_mask=batch["labels_attention_mask"],
744
+ speaker_id=batch["speaker_id"],
745
+ encoder_output = batch['text_encoder_output'],
746
+
747
+ return_dict=True,
748
+ monotonic_alignment_function=None,
749
+ )
750
+
751
+ mel_scaled_labels = batch["mel_scaled_input_features"]
752
+ mel_scaled_target = model.slice_segments(
753
+ mel_scaled_labels, model_outputs_train.ids_slice, model_segment_size
754
+ )
755
+ mel_scaled_generation = feature_extractor._torch_extract_fbank_features(
756
+ model_outputs_train.waveform.squeeze(1)
757
+ )[1]
758
+
759
+ val_losses = compute_val_metrics_and_losses(
760
+ val_losses,
761
+ accelerator,
762
+ model_outputs_train,
763
+ mel_scaled_generation,
764
+ mel_scaled_target,
765
+ per_device_train_batch_size,
766
+ compute_clap_similarity=False,
767
+ )
768
+ specs = feature_extractor._torch_extract_fbank_features(model_outputs_train.waveform.squeeze(1))[0]
769
+ padded_attn, specs, target_specs = accelerator.pad_across_processes(
770
+ [model_outputs_train.attn.squeeze(1), specs, batch["labels"]], dim=1
771
+ )
772
+ padded_attn, specs, target_specs = accelerator.pad_across_processes(
773
+ [padded_attn, specs, target_specs], dim=2
774
+ )
775
+
776
+ generated_train_waveform, padded_attn, specs, target_specs = accelerator.gather_for_metrics(
777
+ [model_outputs_train.waveform, padded_attn, specs, target_specs]
778
+ )
779
+
780
+ if accelerator.is_main_process:
781
+ with torch.no_grad():
782
+ speaker_id = None if num_speakers < 2 else list(range(min(5, num_speakers)))
783
+ full_generation = model(**full_generation_sample.to(model.device), speaker_id=speaker_id)
784
+
785
+ generated_audio.append(generated_train_waveform.cpu())
786
+ generated_attn.append(padded_attn.cpu())
787
+ generated_spec.append(specs.cpu())
788
+ target_spec.append(target_specs.cpu())
789
+
790
+ logger.info("Validation inference done, now evaluating... ")
791
+ if accelerator.is_main_process:
792
+ generated_audio = [audio.numpy() for audio_batch in generated_audio for audio in audio_batch]
793
+ generated_attn = [
794
+ plot_alignment_to_numpy(attn.numpy()) for attn_batch in generated_attn for attn in attn_batch
795
+ ]
796
+ generated_spec = [
797
+ plot_spectrogram_to_numpy(attn.numpy()) for attn_batch in generated_spec for attn in attn_batch
798
+ ]
799
+ target_spec = [
800
+ plot_spectrogram_to_numpy(attn.numpy()) for attn_batch in target_spec for attn in attn_batch
801
+ ]
802
+ full_generation_waveform = full_generation.waveform.cpu().numpy()
803
+
804
+ log_on_trackers(
805
+ accelerator.trackers,
806
+ generated_audio,
807
+ generated_attn,
808
+ generated_spec,
809
+ target_spec,
810
+ full_generation_waveform,
811
+ epoch,
812
+ sampling_rate,
813
+ )
814
+
815
+ accelerator.log(val_losses, step=global_step)
816
+ logger.info("Validation finished... ")
817
+
818
+ accelerator.wait_for_everyone()
819
+
820
+ # unwrap, save and push final model
821
+ model = accelerator.unwrap_model(model)
822
+ discriminator = accelerator.unwrap_model(discriminator)
823
+
824
+ model.discriminator = discriminator
825
+
826
+ # remove weight norms before saving the final model
827
+ for disc in model.discriminator.discriminators:
828
+ disc.remove_weight_norm()
829
+ model.decoder.remove_weight_norm()
830
+ for flow in model.flow.flows:
831
+ torch.nn.utils.remove_weight_norm(flow.conv_pre)
832
+ torch.nn.utils.remove_weight_norm(flow.conv_post)
833
+
834
+ model.save_pretrained(training_args.output_dir)
835
+
836
+ if training_args.push_to_hub:
837
+ VitsModel.from_pretrained(training_args.output_dir).push_to_hub(training_args.hub_model_id)
838
+
839
+ accelerator.end_training()
840
+
841
+
842
+
843
+ logger.info("***** Training / Inference Done *****")
844
+
845
+
846
+
847
+
848
+ #...............................................................................
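For reference, here is a minimal inference sketch (not part of this commit) showing how the checkpoint saved above might be used once training has finished. It assumes the split model keeps the upstream transformers-style `VitsModel.from_pretrained(...)` / `.waveform` inference API used elsewhere in this repo; the output directory and tokenizer checkpoint are placeholders.

import torch
from transformers import AutoTokenizer
from VitsModelSplit.vits_model import VitsModel

# placeholders: point these at training_args.output_dir and the tokenizer actually used
model = VitsModel.from_pretrained("output")
tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-ara")

inputs = tokenizer("some text to synthesize", return_tensors="pt")
with torch.no_grad():
    # assumed to mirror transformers' VitsModel: the output object exposes .waveform
    waveform = model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask).waveform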
VitsModelSplit/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+
2
+ # from . import vits_config,vits_model,vits_output,arguments,decoder,encoder,discriminator,duration_predictor,flow,posterior_encoder
3
+ # from . import PosteriorDecoderModel,plot,trainer
4
+ # from . import dataset_features_collector,feature_extraction
VitsModelSplit/data_collator.py ADDED
@@ -0,0 +1,119 @@
1
+ from typing import Any, Union,List,Dict
2
+ import numpy as np
3
+ import torch
4
+ from dataclasses import dataclass
5
+ from transformers.feature_extraction_utils import BatchFeature
6
+
7
+ from .vits_output import VitsTextEncoderOutput
8
+ #.............................................
9
+
10
+
11
+ @dataclass
12
+ class DataCollatorTTSWithPadding:
13
+ """
14
+ Data collator that will dynamically pad the inputs received.
15
+ Args:
16
+ tokenizer ([`VitsTokenizer`])
17
+ The tokenizer used for processing the data.
18
+ feature_extractor ([`VitsFeatureExtractor`])
19
+ The feature extractor used for processing the data.
20
+ forward_attention_mask (`bool`)
21
+ Whether to return attention_mask.
22
+ """
23
+
24
+ tokenizer: Any
25
+ feature_extractor: Any
26
+ forward_attention_mask: bool
27
+
28
+ def pad_waveform(self, raw_speech):
29
+ is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
30
+ if is_batched_numpy and len(raw_speech.shape) > 2:
31
+ raise ValueError(f"Only mono-channel audio is supported for input to {self}")
32
+ is_batched = is_batched_numpy or (
33
+ isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
34
+ )
35
+
36
+ if is_batched:
37
+ raw_speech = [np.asarray([speech], dtype=np.float32).T for speech in raw_speech]
38
+ elif not is_batched and not isinstance(raw_speech, np.ndarray):
39
+ raw_speech = np.asarray(raw_speech, dtype=np.float32)
40
+ elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
41
+ raw_speech = raw_speech.astype(np.float32)
42
+
43
+ # always return batch
44
+ if not is_batched:
45
+ raw_speech = [np.asarray([raw_speech]).T]
46
+
47
+ batched_speech = BatchFeature({"input_features": raw_speech})
48
+
49
+ # convert into correct format for padding
50
+
51
+ padded_inputs = self.feature_extractor.pad(
52
+ batched_speech,
53
+ padding=True,
54
+ return_attention_mask=False,
55
+ return_tensors="pt",
56
+ )["input_features"]
57
+
58
+ return padded_inputs
59
+
60
+ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
61
+ # split inputs and labels since they have to be of different lengths and need
62
+ # different padding methods
63
+
64
+ model_input_name = "input_ids"
65
+
66
+ input_ids = [{model_input_name: feature[model_input_name][0]} for feature in features]
67
+
68
+ # pad input tokens
69
+ batch = self.tokenizer.pad(input_ids, return_tensors="pt", return_attention_mask=self.forward_attention_mask)
70
+
71
+ # pad waveform
72
+ waveforms = [np.array(feature["waveform"]) for feature in features]
73
+ batch["waveform"] = self.pad_waveform(waveforms)
74
+
75
+ # pad spectrogram
76
+ label_features = [np.array(feature["labels"]) for feature in features]
77
+ labels_batch = self.feature_extractor.pad(
78
+ {"input_features": [i.T for i in label_features]}, return_tensors="pt", return_attention_mask=True
79
+ )
80
+
81
+ labels = labels_batch["input_features"].transpose(1, 2)
82
+ batch["labels"] = labels
83
+ batch["labels_attention_mask"] = labels_batch["attention_mask"]
84
+
85
+ # pad mel spectrogram
86
+ mel_scaled_input_features = {
87
+ "input_features": [np.array(feature["mel_scaled_input_features"]).squeeze().T for feature in features]
88
+ }
89
+ mel_scaled_input_features = self.feature_extractor.pad(
90
+ mel_scaled_input_features, return_tensors="pt", return_attention_mask=True
91
+ )["input_features"].transpose(1, 2)
92
+
93
+ batch["mel_scaled_input_features"] = mel_scaled_input_features
94
+ batch["speaker_id"] = (
95
+ torch.tensor([feature["speaker_id"] for feature in features]) if "speaker_id" in features[0] else None
96
+ )
97
+
98
+
99
+
100
+
101
+
102
+ # text_encoder_output = [{
103
+ # 'last_hidden_state':torch.tensor(features["text_encoder_output"]['last_hidden_state']),
104
+ # 'prior_log_variances':torch.tensor(feature["text_encoder_output"]['prior_log_variances']),
105
+ # 'prior_means':torch.tensor(feature["text_encoder_output"]['prior_means']),
106
+ # } for feature in features]
107
+
108
+ batch['text_encoder_output'] = VitsTextEncoderOutput(
109
+ last_hidden_state=torch.tensor(features[0]["text_encoder_output"]['last_hidden_state']),
110
+ prior_means=torch.tensor(features[0]["text_encoder_output"]['prior_means']),
111
+ prior_log_variances=torch.tensor(features[0]["text_encoder_output"]['prior_log_variances']),
112
+ )
113
+
114
+ # print("DataColl ",batch.keys())
115
+
116
+ return batch
117
+
118
+
119
+ #.............................................................................................
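A hypothetical wiring sketch for the collator above (not part of the commit): the tokenizer checkpoint is a placeholder, and `prepared_dataset` stands for a dataset whose examples already contain the keys the collator reads (`input_ids`, `waveform`, `labels`, `mel_scaled_input_features`, `text_encoder_output`, and optionally `speaker_id`).

import torch
from transformers import AutoTokenizer
from VitsModelSplit.feature_extraction import VitsFeatureExtractor
from VitsModelSplit.data_collator import DataCollatorTTSWithPadding

collator = DataCollatorTTSWithPadding(
    tokenizer=AutoTokenizer.from_pretrained("facebook/mms-tts-ara"),  # placeholder checkpoint
    feature_extractor=VitsFeatureExtractor(),
    forward_attention_mask=True,
)

# prepared_dataset is a placeholder for data produced by the preprocessing step in the next file
loader = torch.utils.data.DataLoader(prepared_dataset, batch_size=4, collate_fn=collator)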
VitsModelSplit/dataset_features_collector.py ADDED
@@ -0,0 +1,402 @@
1
+
2
+ import numpy as np
3
+ import os
4
+ from datasets import Dataset,DatasetDict
5
+ from typing import Union,List,Dict
6
+ import torch
7
+ from dataclasses import dataclass
8
+ from transformers.feature_extraction_utils import BatchFeature
9
+ from VitsModelSplit.feature_extraction import VitsFeatureExtractor
10
+ from VitsModelSplit.vits_model import VitsModel
11
+ from transformers import AutoTokenizer
12
+
13
+ #.............................................
14
+
15
+
16
+ @dataclass
17
+ class DataSetFeaturesCollector:
18
+
19
+ def __init__(self,tokenizer,model,feature_extractor,forward_attention_mask=True) -> None:
20
+ self.tokenizer=tokenizer
21
+ self.feature_extractor = feature_extractor
22
+ self.model=model
23
+ self.forward_attention_mask = forward_attention_mask
24
+
25
+ #.............................................
26
+
27
+ def pad_waveform(self, raw_speech):
28
+
29
+ is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
30
+ if is_batched_numpy and len(raw_speech.shape) > 2:
31
+ raise ValueError(f"Only mono-channel audio is supported for input to {self}")
32
+ is_batched = is_batched_numpy or (
33
+ isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
34
+ )
35
+
36
+ if is_batched:
37
+ raw_speech = [np.asarray([speech], dtype=np.float32).T for speech in raw_speech]
38
+ elif not is_batched and not isinstance(raw_speech, np.ndarray):
39
+ raw_speech = np.asarray(raw_speech, dtype=np.float32)
40
+ elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
41
+ raw_speech = raw_speech.astype(np.float32)
42
+
43
+ # always return batch
44
+ if not is_batched:
45
+ raw_speech = [np.asarray([raw_speech]).T]
46
+
47
+ batched_speech = BatchFeature({"input_features": raw_speech})
48
+
49
+ # convert into correct format for padding
50
+
51
+ padded_inputs = self.feature_extractor.pad(
52
+ batched_speech,
53
+ padding=True,
54
+ return_attention_mask=False,
55
+ return_tensors="pt",
56
+ )["input_features"]
57
+
58
+ return padded_inputs
59
+
60
+ #.............................................
61
+
62
+ def prepare_dataset(self,batch):
63
+
64
+ sample = batch['audio']
65
+ audio_inputs = self.feature_extractor(
66
+ sample,
67
+ sampling_rate=16000,
68
+ return_attention_mask=False,
69
+ do_normalize=False,
70
+ )
71
+
72
+ batch["labels"] = audio_inputs.get("input_features")[0]
73
+ batch["waveform_input_length"] = len(sample)
74
+ batch["waveform"] = batch['audio']
75
+ batch["mel_scaled_input_features"] = audio_inputs.get("mel_scaled_input_features")[0]
76
+ textsample = batch['text']
77
+ inputs = self.tokenizer(textsample, return_tensors="pt")
78
+ inputs = self.tokenizer.pad({'input_ids':inputs.input_ids})
79
+ batch['input_ids'] = inputs.input_ids
80
+ batch['attention_mask'] = inputs.attention_mask
81
+ # batch['speaker_id']=batch['speaker_id']
82
+
83
+
84
+ return batch
85
+
86
+
87
+ #.............................................
88
+
89
+
90
+ def __call__(self, dataset: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
91
+ # split inputs and labels since they have to be of different lengths and need
92
+ # different padding methods
93
+
94
+ dataset = Dataset.from_list(dataset)
95
+ features = dataset.map(
96
+ self.prepare_dataset,
97
+ remove_columns=dataset.column_names,
98
+ desc="preprocess",
99
+ )
100
+
101
+ features = list(features)
102
+
103
+ model_input_name = "input_ids"
104
+
105
+ input_ids = [{model_input_name: feature[model_input_name][0]} for feature in features]
106
+
107
+ # pad input tokens
108
+ batch = self.tokenizer.pad(input_ids, return_tensors="pt", return_attention_mask=self.forward_attention_mask)
109
+
110
+ # pad waveform
111
+ waveforms = [np.array(feature["waveform"]) for feature in features]
112
+ batch["waveform"] = self.pad_waveform(waveforms)
113
+
114
+ # pad spectrogram
115
+ label_features = [np.array(feature["labels"]) for feature in features]
116
+ labels_batch = self.feature_extractor.pad(
117
+ {"input_features": [i.T for i in label_features]}, return_tensors="pt", return_attention_mask=True
118
+ )
119
+
120
+ labels = labels_batch["input_features"].transpose(1, 2)
121
+ batch["labels"] = labels
122
+ batch["labels_attention_mask"] = labels_batch["attention_mask"]
123
+
124
+ # pad mel spectrogram
125
+ mel_scaled_input_features = {
126
+ "input_features": [np.array(feature["mel_scaled_input_features"]).squeeze().T for feature in features]
127
+ }
128
+ mel_scaled_input_features = self.feature_extractor.pad(
129
+ mel_scaled_input_features, return_tensors="pt", return_attention_mask=True
130
+ )["input_features"].transpose(1, 2)
131
+
132
+ batch["mel_scaled_input_features"] = mel_scaled_input_features
133
+ batch["speaker_id"] = (
134
+ torch.tensor([feature["speaker_id"] for feature in dataset]) if "speaker_id" in dataset[0] else None
135
+ )
136
+
137
+ # with torch.no_grad():
138
+ # padding_mask =torch.ones_like(batch['input_ids']).unsqueeze(-1).float()
139
+ # text_encoder_output = self.model.text_encoder(batch['input_ids'],
140
+ # padding_mask=padding_mask,
141
+ # attention_mask = batch['attention_mask']
142
+ # )
143
+ # batch['text_encoder_output'] = text_encoder_output
144
+ # posterior_latents, posterior_means, posterior_log_variances = self.model.posterior_encoder(
145
+ # batch['labels'], batch['labels_attention_mask'].unsqueeze(1).float()
146
+ # )
147
+ # posterior_encode_output={
148
+ # 'posterior_latents':posterior_latents,
149
+ # 'posterior_means':posterior_means,
150
+ # 'posterior_log_variances':posterior_log_variances
151
+ # }
152
+ # batch['posterior_encode_output']=posterior_encode_output
153
+
154
+
155
+
156
+ return batch
157
+
158
+
159
+ #..............................................................
160
+
161
+
162
+
163
+ #.............................................
164
+
165
+ def run_dataset_features_collection(
166
+ dataset_dir,
167
+ train_split_name ="train",
168
+ eval_split_name="eval",
169
+ full_generation_name = 'full_generation',
170
+ tokenizer = None,
171
+ model = None,
172
+ feature_extractor = None,
173
+ train_batch_size = 1,
174
+ eval_batch_size = 1,
175
+ output_dir = "dataset_features"
176
+
177
+ ):
178
+
179
+ dataset = DatasetDict.load_from_disk(dataset_dir)
180
+
181
+ data_collator = DataSetFeaturesCollector(
182
+ tokenizer = tokenizer,
183
+ model = model,
184
+ feature_extractor = feature_extractor,
185
+ forward_attention_mask = True
186
+ )
187
+
188
+ if train_split_name:
189
+ train_dataloader = torch.utils.data.DataLoader(
190
+ dataset[train_split_name],
191
+ shuffle=False,
192
+ collate_fn=data_collator,
193
+ batch_size=train_batch_size,
194
+ sampler=None,
195
+ )
196
+
197
+ train_dir = os.path.join(output_dir,"train")
198
+ os.makedirs(train_dir,exist_ok=True)
199
+
200
+ for step, batch in enumerate(train_dataloader):
201
+ print(f"Train Dataset - batch {step}, waveform {(batch['waveform'].shape)},tokens {(batch['input_ids'].shape)}... ")
202
+ fname = os.path.join(train_dir,f"train-batch-{step}.bin")
203
+ with open(fname, "wb") as f:
204
+ torch.save(batch, f)
205
+
206
+ if eval_split_name:
207
+
208
+ eval_dataloader = torch.utils.data.DataLoader(
209
+ dataset[eval_split_name],
210
+ shuffle=False,
211
+ collate_fn=data_collator,
212
+ batch_size=eval_batch_size,
213
+ sampler=None,
214
+ )
215
+
216
+ eval_dir = os.path.join(output_dir,"eval")
217
+ os.makedirs(eval_dir,exist_ok=True)
218
+
219
+ for step, batch in enumerate(eval_dataloader):
220
+ print(f"Eval Dataset - batch {step}, waveform {(batch['waveform'].shape)},tokens {(batch['input_ids'].shape)}... ")
221
+ fname = os.path.join(eval_dir,f"eval-batch-{step}.bin")
222
+ with open(fname, "wb") as f:
223
+ torch.save(batch, f)
224
+
225
+ if full_generation_name:
226
+
227
+ full_generation_dataloader = torch.utils.data.DataLoader(
228
+ dataset[full_generation_name],
229
+ shuffle=False,
230
+ collate_fn=data_collator,
231
+ batch_size=1,
232
+ sampler=None,
233
+ )
234
+
235
+ full_generation_dir = os.path.join(output_dir,"full_generation")
236
+ os.makedirs(full_generation_dir,exist_ok=True)
237
+
238
+ for step, batch in enumerate(full_generation_dataloader):
239
+ print(f"Full Generation Dataset - batch {step}, waveform {(batch['waveform'].shape)},tokens {(batch['input_ids'].shape)}... ")
240
+ fname = os.path.join(full_generation_dir,f"full-generation-batch-{step}.bin")
241
+ with open(fname, "wb") as f:
242
+ torch.save(batch, f)
243
+
244
+ #...........................................................................
245
+
246
+ import torch.utils.data
247
+
248
+ class FeaturesCollectionDataset(torch.utils.data.Dataset):
249
+
250
+ def __init__(self,dataset_dir,device='cpu') -> None:
251
+ self.dataset_dir = dataset_dir
252
+ self.batchs_path = sorted([os.path.join(self.dataset_dir,file) for file in os.listdir(dataset_dir) if file.endswith('.bin')])
253
+ self.device = device
254
+
255
+ def __len__(self):
256
+ return len(self.batchs_path)
257
+
258
+ def __getitem__(self, idx):
259
+ batch_name = self.batchs_path[idx]
260
+ with open(batch_name, "rb") as f:
261
+ batch = torch.load(f,map_location=torch.device(self.device))
262
+ return batch
263
+
264
+
265
+ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
266
+ """
267
+ Maintain similar input lengths in a batch.
268
+ Length groups are specified by boundaries.
269
+ Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}.
270
+
271
+ It removes samples which are not included in the boundaries.
272
+ Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
273
+ """
274
+ def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True):
275
+ super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
276
+ self.lengths =dataset.lengths
277
+ self.batch_size = batch_size
278
+ self.boundaries = boundaries
279
+
280
+ self.buckets, self.num_samples_per_bucket = self._create_buckets()
281
+ self.total_size = sum(self.num_samples_per_bucket)
282
+ self.num_samples = self.total_size // self.num_replicas
283
+
284
+ def _create_buckets(self):
285
+ buckets = [[] for _ in range(len(self.boundaries) - 1)]
286
+ for i in range(len(self.lengths)):
287
+ length = self.lengths[i]
288
+ idx_bucket = self._bisect(length)
289
+ if idx_bucket != -1:
290
+ buckets[idx_bucket].append(i)
291
+
292
+ for i in range(len(buckets) - 1, 0, -1):
293
+ if len(buckets[i]) == 0:
294
+ buckets.pop(i)
295
+ self.boundaries.pop(i+1)
296
+
297
+ num_samples_per_bucket = []
298
+ for i in range(len(buckets)):
299
+ len_bucket = len(buckets[i])
300
+ total_batch_size = self.num_replicas * self.batch_size
301
+ rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size
302
+ num_samples_per_bucket.append(len_bucket + rem)
303
+ return buckets, num_samples_per_bucket
304
+
305
+ def __iter__(self):
306
+ # deterministically shuffle based on epoch
307
+ g = torch.Generator()
308
+ g.manual_seed(self.epoch)
309
+
310
+ indices = []
311
+ if self.shuffle:
312
+ for bucket in self.buckets:
313
+ indices.append(torch.randperm(len(bucket), generator=g).tolist())
314
+ else:
315
+ for bucket in self.buckets:
316
+ indices.append(list(range(len(bucket))))
317
+
318
+ batches = []
319
+ for i in range(len(self.buckets)):
320
+ bucket = self.buckets[i]
321
+ len_bucket = len(bucket)
322
+ ids_bucket = indices[i]
323
+ num_samples_bucket = self.num_samples_per_bucket[i]
324
+
325
+ # add extra samples to make it evenly divisible
326
+ rem = num_samples_bucket - len_bucket
327
+ ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[:(rem % len_bucket)]
328
+
329
+ # subsample
330
+ ids_bucket = ids_bucket[self.rank::self.num_replicas]
331
+
332
+ # batching
333
+ for j in range(len(ids_bucket) // self.batch_size):
334
+ batch = [bucket[idx] for idx in ids_bucket[j*self.batch_size:(j+1)*self.batch_size]]
335
+ batches.append(batch)
336
+
337
+ if self.shuffle:
338
+ batch_ids = torch.randperm(len(batches), generator=g).tolist()
339
+ batches = [batches[i] for i in batch_ids]
340
+ self.batches = batches
341
+
342
+ assert len(self.batches) * self.batch_size == self.num_samples
343
+ return iter(self.batches)
344
+
345
+ def _bisect(self, x, lo=0, hi=None):
346
+ if hi is None:
347
+ hi = len(self.boundaries) - 1
348
+
349
+ if hi > lo:
350
+ mid = (hi + lo) // 2
351
+ if self.boundaries[mid] < x and x <= self.boundaries[mid+1]:
352
+ return mid
353
+ elif x <= self.boundaries[mid]:
354
+ return self._bisect(x, lo, mid)
355
+ else:
356
+ return self._bisect(x, mid + 1, hi)
357
+ else:
358
+ return -1
359
+
360
+ def __len__(self):
361
+ return self.num_samples // self.batch_size
362
+
+
+ class VitsCollectionDataset(torch.utils.data.Dataset):
363
+
364
+ def __init__(self,dataset,hop_length=256,rate=16_000,device='cpu') -> None:
365
+ self.dataset = dataset
366
+ self.lengths =(torch.tensor(dataset['secs'])*rate//(2*hop_length)).tolist()
367
+ self.device = device
368
+
369
+
370
+
371
+ def __len__(self):
372
+ return self.dataset.num_rows
373
+
374
+
375
+ def __getitem__(self, idx):
376
+ return self.dataset[idx]
377
+
378
+ def get_dataloader(dir_db_train,feature_extractor,name_db='train',batch_size=8,num_workers=0):
379
+ dataset = DatasetDict.load_from_disk(dir_db_train)
380
+ db_train=VitsCollectionDataset(dataset[name_db])
381
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
382
+ model=VitsModel.from_pretrained("facebook/mms-tts-ara").to(device)
383
+ tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-ara",cache_dir="./")#.to("cuda")
384
+ train_sampler = DistributedBucketSampler(
385
+ db_train,
386
+ batch_size,
387
+ [32,300,400,500,600,700,800,900,1000],
388
+ num_replicas=1,
389
+ rank=0,
390
+ shuffle=True)
391
+ data_collator = DataSetFeaturesCollector(
392
+ tokenizer = tokenizer,
393
+ model = model,
394
+ feature_extractor = feature_extractor,
395
+ forward_attention_mask = True
396
+ )
397
+ train_dataloader = torch.utils.data.DataLoader(
398
+ db_train,
399
+ num_workers=num_workers, shuffle=False, pin_memory=True,
400
+ collate_fn=data_collator, batch_sampler=train_sampler
401
+ )
402
+ return train_dataloader
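A sketch of the intended two-step flow in this module: pre-compute collated batches to disk with `run_dataset_features_collection`, then stream them back with `FeaturesCollectionDataset`. The paths and checkpoints below are placeholders, and it assumes the saved `DatasetDict` has `train`/`eval`/`full_generation` splits with `audio` and `text` columns (plus `speaker_id` for multispeaker data).

from transformers import AutoTokenizer
from VitsModelSplit.feature_extraction import VitsFeatureExtractor
from VitsModelSplit.vits_model import VitsModel
from VitsModelSplit.dataset_features_collector import (
    FeaturesCollectionDataset,
    run_dataset_features_collection,
)

# 1) collate and serialise every batch once (placeholder paths/checkpoints)
run_dataset_features_collection(
    dataset_dir="dataset",
    tokenizer=AutoTokenizer.from_pretrained("facebook/mms-tts-ara"),
    model=VitsModel.from_pretrained("facebook/mms-tts-ara"),
    feature_extractor=VitsFeatureExtractor(),
    output_dir="dataset_features",
)

# 2) read the serialised batches back during training
train_batches = FeaturesCollectionDataset("dataset_features/train")
first_batch = train_batches[0]  # dict of padded tensors keyed like the collator output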
VitsModelSplit/decoder.py ADDED
@@ -0,0 +1,168 @@
1
+ import math
2
+ from typing import Optional
3
+ import numpy as np
4
+ import torch
5
+ from torch import nn
6
+ from .vits_config import VitsConfig
7
+
8
+ #.............................................
9
+
10
+
11
+
12
+
13
+
14
+ # Copied from transformers.models.speecht5.modeling_speecht5.HifiGanResidualBlock
15
+ class HifiGanResidualBlock(nn.Module):
16
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), leaky_relu_slope=0.1):
17
+ super().__init__()
18
+ self.leaky_relu_slope = leaky_relu_slope
19
+
20
+ self.convs1 = nn.ModuleList(
21
+ [
22
+ nn.Conv1d(
23
+ channels,
24
+ channels,
25
+ kernel_size,
26
+ stride=1,
27
+ dilation=dilation[i],
28
+ padding=self.get_padding(kernel_size, dilation[i]),
29
+ )
30
+ for i in range(len(dilation))
31
+ ]
32
+ )
33
+ self.convs2 = nn.ModuleList(
34
+ [
35
+ nn.Conv1d(
36
+ channels,
37
+ channels,
38
+ kernel_size,
39
+ stride=1,
40
+ dilation=1,
41
+ padding=self.get_padding(kernel_size, 1),
42
+ )
43
+ for _ in range(len(dilation))
44
+ ]
45
+ )
46
+
47
+ def get_padding(self, kernel_size, dilation=1):
48
+ return (kernel_size * dilation - dilation) // 2
49
+
50
+ def apply_weight_norm(self):
51
+ for layer in self.convs1:
52
+ nn.utils.weight_norm(layer)
53
+ for layer in self.convs2:
54
+ nn.utils.weight_norm(layer)
55
+
56
+ def remove_weight_norm(self):
57
+ for layer in self.convs1:
58
+ nn.utils.remove_weight_norm(layer)
59
+ for layer in self.convs2:
60
+ nn.utils.remove_weight_norm(layer)
61
+
62
+ def forward(self, hidden_states):
63
+ for conv1, conv2 in zip(self.convs1, self.convs2):
64
+ residual = hidden_states
65
+ hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
66
+ hidden_states = conv1(hidden_states)
67
+ hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
68
+ hidden_states = conv2(hidden_states)
69
+ hidden_states = hidden_states + residual
70
+ return hidden_states
71
+
72
+
73
+ #.............................................................................................
74
+
75
+
76
+ class VitsHifiGan(nn.Module):
77
+ def __init__(self, config: VitsConfig):
78
+ super().__init__()
79
+ self.config = config
80
+ self.num_kernels = len(config.resblock_kernel_sizes)
81
+ self.num_upsamples = len(config.upsample_rates)
82
+ self.conv_pre = nn.Conv1d(
83
+ config.flow_size,
84
+ config.upsample_initial_channel,
85
+ kernel_size=7,
86
+ stride=1,
87
+ padding=3,
88
+ )
89
+
90
+ self.upsampler = nn.ModuleList()
91
+ for i, (upsample_rate, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)):
92
+ self.upsampler.append(
93
+ nn.ConvTranspose1d(
94
+ config.upsample_initial_channel // (2**i),
95
+ config.upsample_initial_channel // (2 ** (i + 1)),
96
+ kernel_size=kernel_size,
97
+ stride=upsample_rate,
98
+ padding=(kernel_size - upsample_rate) // 2,
99
+ )
100
+ )
101
+
102
+ self.resblocks = nn.ModuleList()
103
+ for i in range(len(self.upsampler)):
104
+ channels = config.upsample_initial_channel // (2 ** (i + 1))
105
+ for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes):
106
+ self.resblocks.append(HifiGanResidualBlock(channels, kernel_size, dilation, config.leaky_relu_slope))
107
+
108
+ self.conv_post = nn.Conv1d(channels, 1, kernel_size=7, stride=1, padding=3, bias=False)
109
+
110
+ if config.speaker_embedding_size != 0:
111
+ self.cond = nn.Conv1d(config.speaker_embedding_size, config.upsample_initial_channel, 1)
112
+
113
+ def resize_speaker_embedding(self, speaker_embedding_size):
114
+ self.config.speaker_embedding_size = speaker_embedding_size
115
+ self.cond = nn.Conv1d(speaker_embedding_size, self.config.upsample_initial_channel, 1)
116
+ nn.init.kaiming_normal_(self.cond.weight)
117
+ if self.cond.bias is not None:
118
+ k = math.sqrt(self.cond.groups / (self.cond.in_channels * self.cond.kernel_size[0]))
119
+ nn.init.uniform_(self.cond.bias, a=-k, b=k)
120
+
121
+ def apply_weight_norm(self):
122
+ for layer in self.upsampler:
123
+ nn.utils.weight_norm(layer)
124
+ for layer in self.resblocks:
125
+ layer.apply_weight_norm()
126
+
127
+ def remove_weight_norm(self):
128
+ for layer in self.upsampler:
129
+ nn.utils.remove_weight_norm(layer)
130
+ for layer in self.resblocks:
131
+ layer.remove_weight_norm()
132
+
133
+ def forward(
134
+ self, spectrogram: torch.FloatTensor, global_conditioning: Optional[torch.FloatTensor] = None
135
+ ) -> torch.FloatTensor:
136
+ r"""
137
+ Converts a spectrogram into a speech waveform.
138
+
139
+ Args:
140
+ spectrogram (`torch.FloatTensor` of shape `(batch_size, config.spectrogram_bins, sequence_length)`):
141
+ Tensor containing the spectrograms.
142
+ global_conditioning (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_size, 1)`, *optional*):
143
+ Tensor containing speaker embeddings, for multispeaker models.
144
+
145
+ Returns:
146
+ `torch.FloatTensor`: Tensor of shape `(batch_size, 1, num_frames)` containing the speech waveform.
147
+ """
148
+ hidden_states = self.conv_pre(spectrogram)
149
+
150
+ if global_conditioning is not None:
151
+ hidden_states = hidden_states + self.cond(global_conditioning)
152
+
153
+ for i in range(self.num_upsamples):
154
+ hidden_states = nn.functional.leaky_relu(hidden_states, self.config.leaky_relu_slope)
155
+ hidden_states = self.upsampler[i](hidden_states)
156
+
157
+ res_state = self.resblocks[i * self.num_kernels](hidden_states)
158
+ for j in range(1, self.num_kernels):
159
+ res_state += self.resblocks[i * self.num_kernels + j](hidden_states)
160
+ hidden_states = res_state / self.num_kernels
161
+
162
+ hidden_states = nn.functional.leaky_relu(hidden_states)
163
+ hidden_states = self.conv_post(hidden_states)
164
+ waveform = torch.tanh(hidden_states)
165
+ return waveform
166
+
167
+
168
+ #.............................................................................................
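A small shape-check sketch for the vocoder above, assuming the local `VitsConfig` keeps the upstream defaults (e.g. `flow_size=192` and upsample rates `[8, 8, 2, 2]`, so one latent frame maps to 256 output samples).

import torch
from VitsModelSplit.vits_config import VitsConfig
from VitsModelSplit.decoder import VitsHifiGan

config = VitsConfig()            # assumed default hyper-parameters
vocoder = VitsHifiGan(config)

latents = torch.randn(1, config.flow_size, 50)   # (batch, flow_size, frames)
with torch.no_grad():
    waveform = vocoder(latents)                  # (batch, 1, frames * prod(upsample_rates))
print(waveform.shape)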
VitsModelSplit/discriminator.py ADDED
@@ -0,0 +1,162 @@
1
+ from torch import nn
2
+ import torch
3
+
4
+ from .vits_config import VitsPreTrainedModel
5
+
6
+
7
+ #.............................................
8
+
9
+
10
+ class VitsHifiGanDiscriminatorScaleResidualBlock(nn.Module):
11
+ def __init__(self, discriminator_scale_channels, leaky_relu_slope=0.1):
12
+ super().__init__()
13
+ self.leaky_relu_slope = leaky_relu_slope
14
+
15
+ in_channels, out_channels = discriminator_scale_channels[:2]
16
+ self.convs = nn.ModuleList([nn.Conv1d(in_channels, out_channels, 15, 1, padding=7)])
17
+
18
+ groups = 4
19
+ for in_channels, out_channels in zip(discriminator_scale_channels[1:-1], discriminator_scale_channels[2:]):
20
+ self.convs.append(nn.Conv1d(in_channels, out_channels, 41, 4, groups=groups, padding=20))
21
+ groups = groups * 4
22
+
23
+ channel_size = discriminator_scale_channels[-1]
24
+ self.convs.append(nn.Conv1d(channel_size, channel_size, 41, 4, groups=groups, padding=20))
25
+ self.convs.append(nn.Conv1d(channel_size, channel_size, 5, 1, padding=2))
26
+ self.final_conv = nn.Conv1d(channel_size, 1, 3, 1, padding=1)
27
+
28
+ def apply_weight_norm(self):
29
+ for layer in self.convs:
30
+ nn.utils.weight_norm(layer)
31
+ nn.utils.weight_norm(self.final_conv)
32
+
33
+ def remove_weight_norm(self):
34
+ for layer in self.convs:
35
+ nn.utils.remove_weight_norm(layer)
36
+ nn.utils.remove_weight_norm(self.final_conv)
37
+
38
+ def forward(self, hidden_states):
39
+ fmap = []
40
+
41
+ for conv in self.convs:
42
+ hidden_states = conv(hidden_states)
43
+ hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
44
+ fmap.append(hidden_states)
45
+
46
+ hidden_states = self.final_conv(hidden_states)
47
+ fmap.append(hidden_states)
48
+ hidden_states = torch.flatten(hidden_states, 1, -1)
49
+
50
+ return hidden_states, fmap
51
+
52
+
53
+ #.............................................................................................
54
+
55
+ class VitsHifiGanDiscriminatorPeriodResidualBlock(nn.Module):
56
+ def __init__(self, discriminator_period_channels, period, kernel_size=5, stride=3, leaky_relu_slope=0.1):
57
+ super().__init__()
58
+ self.leaky_relu_slope = leaky_relu_slope
59
+ self.period = period
60
+
61
+ self.convs = nn.ModuleList()
62
+ for in_channels, out_channels in zip(discriminator_period_channels[:-1], discriminator_period_channels[1:]):
63
+ self.convs.append(
64
+ nn.Conv2d(
65
+ in_channels,
66
+ out_channels,
67
+ (kernel_size, 1),
68
+ (stride, 1),
69
+ padding=(self.get_padding(kernel_size, 1), 0),
70
+ )
71
+ )
72
+
73
+ channel_size = discriminator_period_channels[-1]
74
+ self.convs.append(
75
+ nn.Conv2d(channel_size, channel_size, (kernel_size, 1), 1, padding=(self.get_padding(kernel_size, 1), 0))
76
+ )
77
+ self.final_conv = nn.Conv2d(channel_size, 1, (3, 1), 1, padding=(1, 0))
78
+
79
+ def get_padding(self, kernel_size, dilation=1):
80
+ return (kernel_size * dilation - dilation) // 2
81
+
82
+ def apply_weight_norm(self):
83
+ for layer in self.convs:
84
+ nn.utils.weight_norm(layer)
85
+ nn.utils.weight_norm(self.final_conv)
86
+
87
+ def remove_weight_norm(self):
88
+ for layer in self.convs:
89
+ nn.utils.remove_weight_norm(layer)
90
+ nn.utils.remove_weight_norm(self.final_conv)
91
+
92
+ def forward(self, hidden_states):
93
+ fmap = []
94
+
95
+ # from 1D to 2D
96
+ batch_size, channels, length = hidden_states.shape
97
+ if length % self.period != 0:
98
+ # pad first
99
+ n_pad = self.period - (length % self.period)
100
+ hidden_states = nn.functional.pad(hidden_states, (0, n_pad), "reflect")
101
+ length = length + n_pad
102
+ hidden_states = hidden_states.view(batch_size, channels, length // self.period, self.period)
103
+
104
+ for conv in self.convs:
105
+ hidden_states = conv(hidden_states)
106
+ hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
107
+ fmap.append(hidden_states)
108
+
109
+ hidden_states = self.final_conv(hidden_states)
110
+ fmap.append(hidden_states)
111
+ hidden_states = torch.flatten(hidden_states, 1, -1)
112
+
113
+ return hidden_states, fmap
114
+
115
+
116
+ #.............................................................................................
117
+
118
+ class VitsDiscriminator(VitsPreTrainedModel):
119
+ def __init__(self, config):
120
+ super().__init__(config)
121
+
122
+ if config.discriminator_scale_channels is not None:
123
+ self.discriminators = nn.ModuleList(
124
+ [VitsHifiGanDiscriminatorScaleResidualBlock(config.discriminator_scale_channels, config.leaky_relu_slope)]
125
+ )
126
+ else:
127
+ self.discriminators = nn.ModuleList([])
128
+
129
+ self.discriminators.extend(
130
+ [
131
+ VitsHifiGanDiscriminatorPeriodResidualBlock(
132
+ config.discriminator_period_channels,
133
+ period,
134
+ config.discriminator_kernel_size,
135
+ config.discriminator_stride,
136
+ config.leaky_relu_slope,
137
+ )
138
+ for period in config.discriminator_periods
139
+ ]
140
+ )
141
+
142
+ def forward(self, hidden_states):
143
+ fmaps = []
144
+ discriminated_hidden_states_list = []
145
+
146
+ for discriminator in self.discriminators:
147
+ discriminated_hidden_states, fmap = discriminator(hidden_states)
148
+ fmaps.append(fmap)
149
+ discriminated_hidden_states_list.append(discriminated_hidden_states)
150
+
151
+ return discriminated_hidden_states_list, fmaps
152
+
153
+ def apply_weight_norm(self):
154
+ for disc in self.discriminators:
155
+ disc.apply_weight_norm()
156
+
157
+ def remove_weight_norm(self):
158
+ for disc in self.discriminators:
159
+ disc.remove_weight_norm()
160
+
161
+
162
+ #.............................................................................................
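A usage sketch for the discriminator, assuming the local `VitsConfig` ships the discriminator defaults this class reads (`discriminator_scale_channels`, `discriminator_period_channels`, `discriminator_periods`, `discriminator_kernel_size`, `discriminator_stride`), as in the finetune-hf-vits configuration this code mirrors.

import torch
from VitsModelSplit.vits_config import VitsConfig
from VitsModelSplit.discriminator import VitsDiscriminator

config = VitsConfig()                      # assumed to carry the discriminator defaults
discriminator = VitsDiscriminator(config)

fake_audio = torch.randn(2, 1, 8192)       # (batch, 1, num_samples)
scores, feature_maps = discriminator(fake_audio)
# one flattened score tensor and one feature-map list per sub-discriminator
print(len(scores), [s.shape for s in scores])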
VitsModelSplit/duration_predictor.py ADDED
@@ -0,0 +1,489 @@
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+
6
+ from .vits_config import VitsConfig
7
+
8
+ #.............................................
9
+
10
+
11
+ def _rational_quadratic_spline(
12
+ inputs,
13
+ unnormalized_widths,
14
+ unnormalized_heights,
15
+ unnormalized_derivatives,
16
+ reverse,
17
+ tail_bound,
18
+ min_bin_width,
19
+ min_bin_height,
20
+ min_derivative,
21
+ ):
22
+ """
23
+ This transformation represents a monotonically increasing piecewise rational quadratic function. Unlike the
24
+ function `_unconstrained_rational_quadratic_spline`, the function behaves the same across the `tail_bound`.
25
+
26
+ Args:
27
+ inputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`:
28
+ Second half of the hidden-states input to the Vits convolutional flow module.
29
+ unnormalized_widths (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
30
+ First `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
31
+ layer in the convolutional flow module
32
+ unnormalized_heights (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
33
+ Second `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
34
+ layer in the convolutional flow module
35
+ unnormalized_derivatives (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
36
+ Third `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
37
+ layer in the convolutional flow module
38
+ reverse (`bool`):
39
+ Whether the model is being run in reverse mode.
40
+ tail_bound (`float`):
41
+ Upper and lower limit bound for the rational quadratic function. Outside of this `tail_bound`, the
42
+ transform behaves as an identity function.
43
+ min_bin_width (`float`):
44
+ Minimum bin value across the width dimension for the piecewise rational quadratic function.
45
+ min_bin_height (`float`):
46
+ Minimum bin value across the height dimension for the piecewise rational quadratic function.
47
+ min_derivative (`float`):
48
+ Minimum bin value across the derivatives for the piecewise rational quadratic function.
49
+ Returns:
50
+ outputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`:
51
+ Hidden-states as transformed by the piecewise rational quadratic function.
52
+ log_abs_det (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`:
53
+ Logarithm of the absolute value of the determinants corresponding to the `outputs`.
54
+ """
55
+ upper_bound = tail_bound
56
+ lower_bound = -tail_bound
57
+ if torch.min(inputs) < lower_bound or torch.max(inputs) > upper_bound:
58
+ raise ValueError("Input to a transform is not within its domain")
59
+
60
+ num_bins = unnormalized_widths.shape[-1]
61
+
62
+ if min_bin_width * num_bins > 1.0:
63
+ raise ValueError(f"Minimal bin width {min_bin_width} too large for the number of bins {num_bins}")
64
+ if min_bin_height * num_bins > 1.0:
65
+ raise ValueError(f"Minimal bin height {min_bin_height} too large for the number of bins {num_bins}")
66
+
67
+ widths = nn.functional.softmax(unnormalized_widths, dim=-1)
68
+ widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
69
+ cumwidths = torch.cumsum(widths, dim=-1)
70
+ cumwidths = nn.functional.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
71
+ cumwidths = (upper_bound - lower_bound) * cumwidths + lower_bound
72
+ cumwidths[..., 0] = lower_bound
73
+ cumwidths[..., -1] = upper_bound
74
+ widths = cumwidths[..., 1:] - cumwidths[..., :-1]
75
+
76
+ derivatives = min_derivative + nn.functional.softplus(unnormalized_derivatives)
77
+
78
+ heights = nn.functional.softmax(unnormalized_heights, dim=-1)
79
+ heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
80
+ cumheights = torch.cumsum(heights, dim=-1)
81
+ cumheights = nn.functional.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
82
+ cumheights = (upper_bound - lower_bound) * cumheights + lower_bound
83
+ cumheights[..., 0] = lower_bound
84
+ cumheights[..., -1] = upper_bound
85
+ heights = cumheights[..., 1:] - cumheights[..., :-1]
86
+
87
+ bin_locations = cumheights if reverse else cumwidths
88
+ bin_locations[..., -1] += 1e-6
89
+ bin_idx = torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
90
+ bin_idx = bin_idx[..., None]
91
+
92
+ input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
93
+ input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
94
+
95
+ input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
96
+ delta = heights / widths
97
+ input_delta = delta.gather(-1, bin_idx)[..., 0]
98
+
99
+ input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
100
+ input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
101
+
102
+ input_heights = heights.gather(-1, bin_idx)[..., 0]
103
+
104
+ intermediate1 = input_derivatives + input_derivatives_plus_one - 2 * input_delta
105
+ if not reverse:
106
+ theta = (inputs - input_cumwidths) / input_bin_widths
107
+ theta_one_minus_theta = theta * (1 - theta)
108
+
109
+ numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
110
+ denominator = input_delta + intermediate1 * theta_one_minus_theta
111
+ outputs = input_cumheights + numerator / denominator
112
+
113
+ derivative_numerator = input_delta.pow(2) * (
114
+ input_derivatives_plus_one * theta.pow(2)
115
+ + 2 * input_delta * theta_one_minus_theta
116
+ + input_derivatives * (1 - theta).pow(2)
117
+ )
118
+ log_abs_det = torch.log(derivative_numerator) - 2 * torch.log(denominator)
119
+ return outputs, log_abs_det
120
+ else:
121
+ # find the roots of a quadratic equation
122
+ intermediate2 = inputs - input_cumheights
123
+ intermediate3 = intermediate2 * intermediate1
124
+ a = input_heights * (input_delta - input_derivatives) + intermediate3
125
+ b = input_heights * input_derivatives - intermediate3
126
+ c = -input_delta * intermediate2
127
+
128
+ discriminant = b.pow(2) - 4 * a * c
129
+ if not (discriminant >= 0).all():
130
+ raise RuntimeError(f"invalid discriminant {discriminant}")
131
+
132
+ root = (2 * c) / (-b - torch.sqrt(discriminant))
133
+ outputs = root * input_bin_widths + input_cumwidths
134
+
135
+ theta_one_minus_theta = root * (1 - root)
136
+ denominator = input_delta + intermediate1 * theta_one_minus_theta
137
+ derivative_numerator = input_delta.pow(2) * (
138
+ input_derivatives_plus_one * root.pow(2)
139
+ + 2 * input_delta * theta_one_minus_theta
140
+ + input_derivatives * (1 - root).pow(2)
141
+ )
142
+ log_abs_det = torch.log(derivative_numerator) - 2 * torch.log(denominator)
143
+ return outputs, -log_abs_det
144
+
145
+ #.............................................
146
+
147
+ def _unconstrained_rational_quadratic_spline(
148
+ inputs,
149
+ unnormalized_widths,
150
+ unnormalized_heights,
151
+ unnormalized_derivatives,
152
+ reverse=False,
153
+ tail_bound=5.0,
154
+ min_bin_width=1e-3,
155
+ min_bin_height=1e-3,
156
+ min_derivative=1e-3,
157
+ ):
158
+ """
159
+ This transformation represents a monotonically increasing piecewise rational quadratic function. Outside of the
160
+ `tail_bound`, the transform behaves as an identity function.
161
+
162
+ Args:
163
+ inputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`:
164
+ Second half of the hidden-states input to the Vits convolutional flow module.
165
+ unnormalized_widths (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
166
+ First `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
167
+ layer in the convolutional flow module
168
+ unnormalized_heights (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
169
+ Second `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
170
+ layer in the convolutional flow module
171
+ unnormalized_derivatives (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
172
+ Third `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
173
+ layer in the convolutional flow module
174
+ reverse (`bool`, *optional*, defaults to `False`):
175
+ Whether the model is being run in reverse mode.
176
+ tail_bound (`float`, *optional* defaults to 5):
177
+ Upper and lower limit bound for the rational quadratic function. Outside of this `tail_bound`, the
178
+ transform behaves as an identity function.
179
+ min_bin_width (`float`, *optional*, defaults to 1e-3):
180
+ Minimum bin value across the width dimension for the piecewise rational quadratic function.
181
+ min_bin_height (`float`, *optional*, defaults to 1e-3):
182
+ Minimum bin value across the height dimension for the piecewise rational quadratic function.
183
+ min_derivative (`float`, *optional*, defaults to 1e-3):
184
+ Minimum bin value across the derivatives for the piecewise rational quadratic function.
185
+ Returns:
186
+ outputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`:
187
+ Hidden-states as transformed by the piecewise rational quadratic function with the `tail_bound` limits
188
+ applied.
189
+ log_abs_det (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`:
190
+ Logarithm of the absolute value of the determinants corresponding to the `outputs` with the `tail_bound`
191
+ limits applied.
192
+ """
193
+ inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
194
+ outside_interval_mask = ~inside_interval_mask
195
+
196
+ outputs = torch.zeros_like(inputs)
197
+ log_abs_det = torch.zeros_like(inputs)
198
+ constant = np.log(np.exp(1 - min_derivative) - 1)
199
+
200
+ unnormalized_derivatives = nn.functional.pad(unnormalized_derivatives, pad=(1, 1))
201
+ unnormalized_derivatives[..., 0] = constant
202
+ unnormalized_derivatives[..., -1] = constant
203
+
204
+ outputs[outside_interval_mask] = inputs[outside_interval_mask]
205
+ log_abs_det[outside_interval_mask] = 0.0
206
+
207
+ outputs[inside_interval_mask], log_abs_det[inside_interval_mask] = _rational_quadratic_spline(
208
+ inputs=inputs[inside_interval_mask],
209
+ unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
210
+ unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
211
+ unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
212
+ reverse=reverse,
213
+ tail_bound=tail_bound,
214
+ min_bin_width=min_bin_width,
215
+ min_bin_height=min_bin_height,
216
+ min_derivative=min_derivative,
217
+ )
218
+ return outputs, log_abs_det
219
+
220
+
221
+ #.............................................................................................
222
+
223
+ class VitsConvFlow(nn.Module):
224
+ def __init__(self, config: VitsConfig):
225
+ super().__init__()
226
+ self.filter_channels = config.hidden_size
227
+ self.half_channels = config.depth_separable_channels // 2
228
+ self.num_bins = config.duration_predictor_flow_bins
229
+ self.tail_bound = config.duration_predictor_tail_bound
230
+
231
+ self.conv_pre = nn.Conv1d(self.half_channels, self.filter_channels, 1)
232
+ self.conv_dds = VitsDilatedDepthSeparableConv(config)
233
+ self.conv_proj = nn.Conv1d(self.filter_channels, self.half_channels * (self.num_bins * 3 - 1), 1)
234
+
235
+ def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
236
+ first_half, second_half = torch.split(inputs, [self.half_channels] * 2, dim=1)
237
+
238
+ hidden_states = self.conv_pre(first_half)
239
+ hidden_states = self.conv_dds(hidden_states, padding_mask, global_conditioning)
240
+ hidden_states = self.conv_proj(hidden_states) * padding_mask
241
+
242
+ batch_size, channels, length = first_half.shape
243
+ hidden_states = hidden_states.reshape(batch_size, channels, -1, length).permute(0, 1, 3, 2)
244
+
245
+ unnormalized_widths = hidden_states[..., : self.num_bins] / math.sqrt(self.filter_channels)
246
+ unnormalized_heights = hidden_states[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels)
247
+ unnormalized_derivatives = hidden_states[..., 2 * self.num_bins :]
248
+
249
+ second_half, log_abs_det = _unconstrained_rational_quadratic_spline(
250
+ second_half,
251
+ unnormalized_widths,
252
+ unnormalized_heights,
253
+ unnormalized_derivatives,
254
+ reverse=reverse,
255
+ tail_bound=self.tail_bound,
256
+ )
257
+
258
+ outputs = torch.cat([first_half, second_half], dim=1) * padding_mask
259
+ if not reverse:
260
+ log_determinant = torch.sum(log_abs_det * padding_mask, [1, 2])
261
+ return outputs, log_determinant
262
+ else:
263
+ return outputs, None
264
+
265
+
266
+ #.............................................................................................
267
+
268
+ class VitsElementwiseAffine(nn.Module):
269
+ def __init__(self, config: VitsConfig):
270
+ super().__init__()
271
+ self.channels = config.depth_separable_channels
272
+ self.translate = nn.Parameter(torch.zeros(self.channels, 1))
273
+ self.log_scale = nn.Parameter(torch.zeros(self.channels, 1))
274
+
275
+ def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
276
+ if not reverse:
277
+ outputs = self.translate + torch.exp(self.log_scale) * inputs
278
+ outputs = outputs * padding_mask
279
+ log_determinant = torch.sum(self.log_scale * padding_mask, [1, 2])
280
+ return outputs, log_determinant
281
+ else:
282
+ outputs = (inputs - self.translate) * torch.exp(-self.log_scale) * padding_mask
283
+ return outputs, None
284
+
285
+ #.............................................................................................
286
+
287
+ class VitsDilatedDepthSeparableConv(nn.Module):
288
+ def __init__(self, config: VitsConfig, dropout_rate=0.0):
289
+ super().__init__()
290
+ kernel_size = config.duration_predictor_kernel_size
291
+ channels = config.hidden_size
292
+ self.num_layers = config.depth_separable_num_layers
293
+
294
+ self.dropout = nn.Dropout(dropout_rate)
295
+ self.convs_dilated = nn.ModuleList()
296
+ self.convs_pointwise = nn.ModuleList()
297
+ self.norms_1 = nn.ModuleList()
298
+ self.norms_2 = nn.ModuleList()
299
+ for i in range(self.num_layers):
300
+ dilation = kernel_size**i
301
+ padding = (kernel_size * dilation - dilation) // 2
302
+ self.convs_dilated.append(
303
+ nn.Conv1d(
304
+ in_channels=channels,
305
+ out_channels=channels,
306
+ kernel_size=kernel_size,
307
+ groups=channels,
308
+ dilation=dilation,
309
+ padding=padding,
310
+ )
311
+ )
312
+ self.convs_pointwise.append(nn.Conv1d(channels, channels, 1))
313
+ self.norms_1.append(nn.LayerNorm(channels))
314
+ self.norms_2.append(nn.LayerNorm(channels))
315
+
316
+ def forward(self, inputs, padding_mask, global_conditioning=None):
317
+ if global_conditioning is not None:
318
+ inputs = inputs + global_conditioning
319
+
320
+ for i in range(self.num_layers):
321
+ hidden_states = self.convs_dilated[i](inputs * padding_mask)
322
+ hidden_states = self.norms_1[i](hidden_states.transpose(1, -1)).transpose(1, -1)
323
+ hidden_states = nn.functional.gelu(hidden_states)
324
+ hidden_states = self.convs_pointwise[i](hidden_states)
325
+ hidden_states = self.norms_2[i](hidden_states.transpose(1, -1)).transpose(1, -1)
326
+ hidden_states = nn.functional.gelu(hidden_states)
327
+ hidden_states = self.dropout(hidden_states)
328
+ inputs = inputs + hidden_states
329
+
330
+ return inputs * padding_mask
331
+
332
+ #.............................................................................................
333
+
334
+ class VitsStochasticDurationPredictor(nn.Module):
335
+ def __init__(self, config):
336
+ super().__init__()
337
+ embed_dim = config.speaker_embedding_size
338
+ filter_channels = config.hidden_size
339
+
340
+ self.conv_pre = nn.Conv1d(filter_channels, filter_channels, 1)
341
+ self.conv_proj = nn.Conv1d(filter_channels, filter_channels, 1)
342
+ self.conv_dds = VitsDilatedDepthSeparableConv(
343
+ config,
344
+ dropout_rate=config.duration_predictor_dropout,
345
+ )
346
+
347
+ if embed_dim != 0:
348
+ self.cond = nn.Conv1d(embed_dim, filter_channels, 1)
349
+
350
+ self.flows = nn.ModuleList()
351
+ self.flows.append(VitsElementwiseAffine(config))
352
+ for _ in range(config.duration_predictor_num_flows):
353
+ self.flows.append(VitsConvFlow(config))
354
+
355
+ self.post_conv_pre = nn.Conv1d(1, filter_channels, 1)
356
+ self.post_conv_proj = nn.Conv1d(filter_channels, filter_channels, 1)
357
+ self.post_conv_dds = VitsDilatedDepthSeparableConv(
358
+ config,
359
+ dropout_rate=config.duration_predictor_dropout,
360
+ )
361
+
362
+ self.post_flows = nn.ModuleList()
363
+ self.post_flows.append(VitsElementwiseAffine(config))
364
+ for _ in range(config.duration_predictor_num_flows):
365
+ self.post_flows.append(VitsConvFlow(config))
366
+
367
+ self.filter_channels = filter_channels
368
+
369
+ def resize_speaker_embeddings(self, speaker_embedding_size):
370
+ self.cond = nn.Conv1d(speaker_embedding_size, self.filter_channels, 1)
371
+
372
+ def forward(self, inputs, padding_mask, global_conditioning=None, durations=None, reverse=False, noise_scale=1.0):
373
+ inputs = torch.detach(inputs)
374
+ inputs = self.conv_pre(inputs)
375
+
376
+ if global_conditioning is not None:
377
+ global_conditioning = torch.detach(global_conditioning)
378
+ inputs = inputs + self.cond(global_conditioning)
379
+
380
+ inputs = self.conv_dds(inputs, padding_mask)
381
+ inputs = self.conv_proj(inputs) * padding_mask
382
+
383
+ if not reverse:
384
+ hidden_states = self.post_conv_pre(durations)
385
+ hidden_states = self.post_conv_dds(hidden_states, padding_mask)
386
+ hidden_states = self.post_conv_proj(hidden_states) * padding_mask
387
+
388
+ random_posterior = (
389
+ torch.randn(durations.size(0), 2, durations.size(2)).to(device=inputs.device, dtype=inputs.dtype)
390
+ * padding_mask
391
+ )
392
+ latents_posterior = random_posterior
393
+
394
+ latents_posterior, log_determinant = self.post_flows[0](
395
+ latents_posterior, padding_mask, global_conditioning=inputs + hidden_states
396
+ )
397
+ log_determinant_posterior_sum = log_determinant
398
+
399
+ for flow in self.post_flows[1:]:
400
+ latents_posterior, log_determinant = flow(
401
+ latents_posterior, padding_mask, global_conditioning=inputs + hidden_states
402
+ )
403
+ latents_posterior = torch.flip(latents_posterior, [1])
404
+ log_determinant_posterior_sum += log_determinant
405
+
406
+ first_half, second_half = torch.split(latents_posterior, [1, 1], dim=1)
407
+
408
+ log_determinant_posterior_sum += torch.sum(
409
+ (nn.functional.logsigmoid(first_half) + nn.functional.logsigmoid(-first_half)) * padding_mask, [1, 2]
410
+ )
411
+ logq = (
412
+ torch.sum(-0.5 * (math.log(2 * math.pi) + (random_posterior**2)) * padding_mask, [1, 2])
413
+ - log_determinant_posterior_sum
414
+ )
415
+
416
+ first_half = (durations - torch.sigmoid(first_half)) * padding_mask
417
+ first_half = torch.log(torch.clamp_min(first_half, 1e-5)) * padding_mask
418
+ log_determinant_sum = torch.sum(-first_half, [1, 2])
419
+
420
+ latents = torch.cat([first_half, second_half], dim=1)
421
+ latents, log_determinant = self.flows[0](latents, padding_mask, global_conditioning=inputs)
422
+
423
+ log_determinant_sum += log_determinant
424
+ for flow in self.flows[1:]:
425
+ latents, log_determinant = flow(latents, padding_mask, global_conditioning=inputs)
426
+ latents = torch.flip(latents, [1])
427
+ log_determinant_sum += log_determinant
428
+
429
+ nll = torch.sum(0.5 * (math.log(2 * math.pi) + (latents**2)) * padding_mask, [1, 2]) - log_determinant_sum
430
+ return nll + logq
431
+ else:
432
+ flows = list(reversed(self.flows))
433
+ flows = flows[:-2] + [flows[-1]] # remove a useless vflow
434
+
435
+ latents = (
436
+ torch.randn(inputs.size(0), 2, inputs.size(2)).to(device=inputs.device, dtype=inputs.dtype)
437
+ * noise_scale
438
+ )
439
+ for flow in flows:
440
+ latents = torch.flip(latents, [1])
441
+ latents, _ = flow(latents, padding_mask, global_conditioning=inputs, reverse=True)
442
+
443
+ log_duration, _ = torch.split(latents, [1, 1], dim=1)
444
+ return log_duration
445
+
446
+ #.............................................................................................
447
+
448
+ class VitsDurationPredictor(nn.Module):
449
+ def __init__(self, config):
450
+ super().__init__()
451
+ kernel_size = config.duration_predictor_kernel_size
452
+ filter_channels = config.duration_predictor_filter_channels
453
+
454
+ self.dropout = nn.Dropout(config.duration_predictor_dropout)
455
+ self.conv_1 = nn.Conv1d(config.hidden_size, filter_channels, kernel_size, padding=kernel_size // 2)
456
+ self.norm_1 = nn.LayerNorm(filter_channels, eps=config.layer_norm_eps)
457
+ self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
458
+ self.norm_2 = nn.LayerNorm(filter_channels, eps=config.layer_norm_eps)
459
+ self.proj = nn.Conv1d(filter_channels, 1, 1)
460
+
461
+ if config.speaker_embedding_size != 0:
462
+ self.cond = nn.Conv1d(config.speaker_embedding_size, config.hidden_size, 1)
463
+
464
+ self.hidden_size = config.hidden_size
465
+
466
+ def resize_speaker_embeddings(self, speaker_embedding_size):
467
+ self.cond = nn.Conv1d(speaker_embedding_size, self.hidden_size, 1)
468
+
469
+ def forward(self, inputs, padding_mask, global_conditioning=None):
470
+ inputs = torch.detach(inputs)
471
+
472
+ if global_conditioning is not None:
473
+ global_conditioning = torch.detach(global_conditioning)
474
+ inputs = inputs + self.cond(global_conditioning)
475
+
476
+ inputs = self.conv_1(inputs * padding_mask)
477
+ inputs = torch.relu(inputs)
478
+ inputs = self.norm_1(inputs.transpose(1, -1)).transpose(1, -1)
479
+ inputs = self.dropout(inputs)
480
+
481
+ inputs = self.conv_2(inputs * padding_mask)
482
+ inputs = torch.relu(inputs)
483
+ inputs = self.norm_2(inputs.transpose(1, -1)).transpose(1, -1)
484
+ inputs = self.dropout(inputs)
485
+
486
+ inputs = self.proj(inputs * padding_mask)
487
+ return inputs * padding_mask
488
+
489
+ #.............................................................................................
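For reference, a minimal usage sketch of the deterministic VitsDurationPredictor defined above (illustrative only, not part of this commit). It assumes the `VitsModelSplit` package is importable (e.g. the repository root is on `PYTHONPATH`) and uses the default `VitsConfig` values; converting log-durations to frame counts with `exp`/`ceil` follows the usual VITS convention.

import torch
from VitsModelSplit.vits_config import VitsConfig                 # assumed import path
from VitsModelSplit.duration_predictor import VitsDurationPredictor

config = VitsConfig()                                              # hidden_size=192 by default
predictor = VitsDurationPredictor(config).eval()

hidden_states = torch.randn(2, config.hidden_size, 50)            # [batch, hidden, time]
padding_mask = torch.ones(2, 1, 50)                               # 1 = real position, 0 = padding

with torch.no_grad():
    log_durations = predictor(hidden_states, padding_mask)        # [batch, 1, time]
durations = torch.ceil(torch.exp(log_durations) * padding_mask)   # per-token frame counts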
VitsModelSplit/encoder.py ADDED
@@ -0,0 +1,407 @@
1
+ import math
2
+ from typing import Optional, Tuple, Union
3
+ import numpy as np
4
+ import torch
5
+ from torch import nn
6
+ from transformers.activations import ACT2FN
7
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
8
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
9
+ from transformers.modeling_outputs import BaseModelOutput
10
+
11
+ from .vits_config import VitsConfig
12
+ from .vits_output import VitsTextEncoderOutput
13
+
14
+
15
+ #....................................................
16
+
17
+
18
+
19
+
20
+
21
+ class VitsFeedForward(nn.Module):
22
+ def __init__(self, config):
23
+ super().__init__()
24
+ self.conv_1 = nn.Conv1d(config.hidden_size, config.ffn_dim, config.ffn_kernel_size)
25
+ self.conv_2 = nn.Conv1d(config.ffn_dim, config.hidden_size, config.ffn_kernel_size)
26
+ self.dropout = nn.Dropout(config.activation_dropout)
27
+
28
+ if isinstance(config.hidden_act, str):
29
+ self.act_fn = ACT2FN[config.hidden_act]
30
+ else:
31
+ self.act_fn = config.hidden_act
32
+
33
+ if config.ffn_kernel_size > 1:
34
+ pad_left = (config.ffn_kernel_size - 1) // 2
35
+ pad_right = config.ffn_kernel_size // 2
36
+ self.padding = [pad_left, pad_right, 0, 0, 0, 0]
37
+ else:
38
+ self.padding = None
39
+
40
+ def forward(self, hidden_states, padding_mask):
41
+ hidden_states = hidden_states.permute(0, 2, 1)
42
+ padding_mask = padding_mask.permute(0, 2, 1)
43
+
44
+ hidden_states = hidden_states * padding_mask
45
+ if self.padding is not None:
46
+ hidden_states = nn.functional.pad(hidden_states, self.padding)
47
+
48
+ hidden_states = self.conv_1(hidden_states)
49
+ hidden_states = self.act_fn(hidden_states)
50
+ hidden_states = self.dropout(hidden_states)
51
+
52
+ hidden_states = hidden_states * padding_mask
53
+ if self.padding is not None:
54
+ hidden_states = nn.functional.pad(hidden_states, self.padding)
55
+
56
+ hidden_states = self.conv_2(hidden_states)
57
+ hidden_states = hidden_states * padding_mask
58
+
59
+ hidden_states = hidden_states.permute(0, 2, 1)
60
+ return hidden_states
61
+
62
+
63
+ #.............................................................................................
64
+
65
+ class VitsAttention(nn.Module):
66
+ """Multi-headed attention with relative positional representation."""
67
+
68
+ def __init__(self, config: VitsConfig):
69
+ super().__init__()
70
+ self.embed_dim = config.hidden_size
71
+ self.num_heads = config.num_attention_heads
72
+ self.dropout = config.attention_dropout
73
+ self.window_size = config.window_size
74
+
75
+ self.head_dim = self.embed_dim // self.num_heads
76
+ self.scaling = self.head_dim**-0.5
77
+
78
+ if (self.head_dim * self.num_heads) != self.embed_dim:
79
+ raise ValueError(
80
+ f"hidden_size must be divisible by num_attention_heads (got `hidden_size`: {self.embed_dim}"
81
+ f" and `num_attention_heads`: {self.num_heads})."
82
+ )
83
+
84
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
85
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
86
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
87
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
88
+
89
+ if self.window_size:
90
+ self.emb_rel_k = nn.Parameter(torch.randn(1, self.window_size * 2 + 1, self.head_dim) * self.scaling)
91
+ self.emb_rel_v = nn.Parameter(torch.randn(1, self.window_size * 2 + 1, self.head_dim) * self.scaling)
92
+
93
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
94
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
95
+
96
+ def forward(
97
+ self,
98
+ hidden_states: torch.Tensor,
99
+ key_value_states: Optional[torch.Tensor] = None,
100
+ attention_mask: Optional[torch.Tensor] = None,
101
+ layer_head_mask: Optional[torch.Tensor] = None,
102
+ output_attentions: bool = False,
103
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
104
+ """Input shape: Batch x Time x Channel"""
105
+
106
+ # if key_value_states are provided this layer is used as a cross-attention layer
107
+ # for the decoder
108
+
109
+ bsz, tgt_len, _ = hidden_states.size()
110
+
111
+ # get query proj
112
+ query_states = self.q_proj(hidden_states) * self.scaling
113
+
114
+ # self_attention
115
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
116
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
117
+
118
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
119
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
120
+ key_states = key_states.view(*proj_shape)
121
+ value_states = value_states.view(*proj_shape)
122
+
123
+ src_len = key_states.size(1)
124
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
125
+
126
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
127
+ raise ValueError(
128
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
129
+ f" {attn_weights.size()}"
130
+ )
131
+
132
+ if self.window_size is not None:
133
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, src_len)
134
+ relative_logits = torch.matmul(query_states, key_relative_embeddings.transpose(-2, -1))
135
+ rel_pos_bias = self._relative_position_to_absolute_position(relative_logits)
136
+ attn_weights += rel_pos_bias
137
+
138
+ if attention_mask is not None:
139
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
140
+ raise ValueError(
141
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
142
+ )
143
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
144
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
145
+
146
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
147
+
148
+ if layer_head_mask is not None:
149
+ if layer_head_mask.size() != (self.num_heads,):
150
+ raise ValueError(
151
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
152
+ f" {layer_head_mask.size()}"
153
+ )
154
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
155
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
156
+
157
+ if output_attentions:
158
+ # this operation is a bit awkward, but it's required to
159
+ # make sure that attn_weights keeps its gradient.
160
+ # In order to do so, attn_weights have to be reshaped
161
+ # twice and have to be reused in the following
162
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
163
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
164
+ else:
165
+ attn_weights_reshaped = None
166
+
167
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
168
+
169
+ attn_output = torch.bmm(attn_probs, value_states)
170
+
171
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
172
+ raise ValueError(
173
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
174
+ f" {attn_output.size()}"
175
+ )
176
+
177
+ if self.window_size is not None:
178
+ value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, src_len)
179
+ relative_weights = self._absolute_position_to_relative_position(attn_probs)
180
+ rel_pos_bias = torch.matmul(relative_weights, value_relative_embeddings)
181
+ attn_output += rel_pos_bias
182
+
183
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
184
+ attn_output = attn_output.transpose(1, 2)
185
+
186
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
187
+ # partitioned across GPUs when using tensor-parallelism.
188
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
189
+
190
+ attn_output = self.out_proj(attn_output)
191
+
192
+ return attn_output, attn_weights_reshaped
193
+
194
+ def _get_relative_embeddings(self, relative_embeddings, length):
195
+ pad_length = max(length - (self.window_size + 1), 0)
196
+ if pad_length > 0:
197
+ relative_embeddings = nn.functional.pad(relative_embeddings, [0, 0, pad_length, pad_length, 0, 0])
198
+
199
+ slice_start_position = max((self.window_size + 1) - length, 0)
200
+ slice_end_position = slice_start_position + 2 * length - 1
201
+ return relative_embeddings[:, slice_start_position:slice_end_position]
202
+
203
+ def _relative_position_to_absolute_position(self, x):
204
+ batch_heads, length, _ = x.size()
205
+
206
+ # Concat columns of pad to shift from relative to absolute indexing.
207
+ x = nn.functional.pad(x, [0, 1, 0, 0, 0, 0])
208
+
209
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
210
+ x_flat = x.view([batch_heads, length * 2 * length])
211
+ x_flat = nn.functional.pad(x_flat, [0, length - 1, 0, 0])
212
+
213
+ # Reshape and slice out the padded elements.
214
+ x_final = x_flat.view([batch_heads, length + 1, 2 * length - 1])
215
+ x_final = x_final[:, :length, length - 1 :]
216
+ return x_final
217
+
218
+ def _absolute_position_to_relative_position(self, x):
219
+ batch_heads, length, _ = x.size()
220
+
221
+ # Pad along column
222
+ x = nn.functional.pad(x, [0, length - 1, 0, 0, 0, 0])
223
+ x_flat = x.view([batch_heads, length**2 + length * (length - 1)])
224
+
225
+ # Add 0's in the beginning that will skew the elements after reshape
226
+ x_flat = nn.functional.pad(x_flat, [length, 0, 0, 0])
227
+ x_final = x_flat.view([batch_heads, length, 2 * length])[:, :, 1:]
228
+ return x_final
229
+
230
+
231
+ #.............................................................................................
232
+
233
+ class VitsEncoderLayer(nn.Module):
234
+ def __init__(self, config: VitsConfig):
235
+ super().__init__()
236
+ self.attention = VitsAttention(config)
237
+ self.dropout = nn.Dropout(config.hidden_dropout)
238
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
239
+ self.feed_forward = VitsFeedForward(config)
240
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
241
+
242
+ def forward(
243
+ self,
244
+ hidden_states: torch.Tensor,
245
+ padding_mask: torch.FloatTensor,
246
+ attention_mask: Optional[torch.Tensor] = None,
247
+ output_attentions: bool = False,
248
+ ):
249
+ residual = hidden_states
250
+ hidden_states, attn_weights = self.attention(
251
+ hidden_states=hidden_states,
252
+ attention_mask=attention_mask,
253
+ output_attentions=output_attentions,
254
+ )
255
+
256
+ hidden_states = self.dropout(hidden_states)
257
+ hidden_states = self.layer_norm(residual + hidden_states)
258
+
259
+ residual = hidden_states
260
+ hidden_states = self.feed_forward(hidden_states, padding_mask)
261
+ hidden_states = self.dropout(hidden_states)
262
+ hidden_states = self.final_layer_norm(residual + hidden_states)
263
+
264
+ outputs = (hidden_states,)
265
+
266
+ if output_attentions:
267
+ outputs += (attn_weights,)
268
+
269
+ return outputs
270
+
271
+ #.............................................................................................
272
+
273
+ class VitsEncoder(nn.Module):
274
+ def __init__(self, config: VitsConfig):
275
+ super().__init__()
276
+ self.config = config
277
+ self.layers = nn.ModuleList([VitsEncoderLayer(config) for _ in range(config.num_hidden_layers)])
278
+ self.gradient_checkpointing = False
279
+ self.layerdrop = config.layerdrop
280
+
281
+ def forward(
282
+ self,
283
+ hidden_states: torch.FloatTensor,
284
+ padding_mask: torch.FloatTensor,
285
+ attention_mask: Optional[torch.Tensor] = None,
286
+ output_attentions: Optional[bool] = None,
287
+ output_hidden_states: Optional[bool] = None,
288
+ return_dict: Optional[bool] = None,
289
+ ) -> Union[Tuple, BaseModelOutput]:
290
+ all_hidden_states = () if output_hidden_states else None
291
+ all_self_attentions = () if output_attentions else None
292
+
293
+ # expand attention_mask
294
+ if attention_mask is not None:
295
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
296
+ attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
297
+
298
+ hidden_states = hidden_states * padding_mask
299
+
300
+ deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
301
+
302
+ for encoder_layer in self.layers:
303
+ if output_hidden_states:
304
+ all_hidden_states = all_hidden_states + (hidden_states,)
305
+
306
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
307
+ dropout_probability = np.random.uniform(0, 1)
308
+
309
+ skip_the_layer = self.training and (dropout_probability < self.layerdrop)
310
+ if not skip_the_layer or deepspeed_zero3_is_enabled:
311
+ # under deepspeed zero3 all gpus must run in sync
312
+ if self.gradient_checkpointing and self.training:
313
+ layer_outputs = self._gradient_checkpointing_func(
314
+ encoder_layer.__call__,
315
+ hidden_states,
316
+ padding_mask,
317
+ attention_mask,
318
+ output_attentions,
319
+ )
320
+ else:
321
+ layer_outputs = encoder_layer(
322
+ hidden_states,
323
+ attention_mask=attention_mask,
324
+ padding_mask=padding_mask,
325
+ output_attentions=output_attentions,
326
+ )
327
+ hidden_states = layer_outputs[0]
328
+
329
+ if skip_the_layer:
330
+ layer_outputs = (None, None)
331
+
332
+ if output_attentions:
333
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
334
+
335
+ hidden_states = hidden_states * padding_mask
336
+
337
+ if output_hidden_states:
338
+ all_hidden_states = all_hidden_states + (hidden_states,)
339
+
340
+ if not return_dict:
341
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
342
+
343
+ return BaseModelOutput(
344
+ last_hidden_state=hidden_states,
345
+ hidden_states=all_hidden_states,
346
+ attentions=all_self_attentions,
347
+ )
348
+
349
+ #.............................................................................................
350
+
351
+ class VitsTextEncoder(nn.Module):
352
+ """
353
+ Transformer encoder that uses relative positional representation instead of absolute positional encoding.
354
+ """
355
+
356
+ def __init__(self, config: VitsConfig):
357
+ super().__init__()
358
+ self.config = config
359
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
360
+
361
+ self.encoder = VitsEncoder(config)
362
+ self.project = nn.Conv1d(config.hidden_size, config.flow_size * 2, kernel_size=1)
363
+
364
+ def get_input_embeddings(self):
365
+ return self.embed_tokens
366
+
367
+ def set_input_embeddings(self, value):
368
+ self.embed_tokens = value
369
+
370
+ def forward(
371
+ self,
372
+ input_ids: torch.Tensor,
373
+ padding_mask: torch.FloatTensor,
374
+ attention_mask: Optional[torch.Tensor] = None,
375
+ output_attentions: Optional[bool] = None,
376
+ output_hidden_states: Optional[bool] = None,
377
+ return_dict: Optional[bool] = True,
378
+ ) -> Union[Tuple[torch.Tensor], VitsTextEncoderOutput]:
379
+ hidden_states = self.embed_tokens(input_ids) * math.sqrt(self.config.hidden_size)
380
+
381
+ encoder_outputs = self.encoder(
382
+ hidden_states=hidden_states,
383
+ padding_mask=padding_mask,
384
+ attention_mask=attention_mask,
385
+ output_attentions=output_attentions,
386
+ output_hidden_states=output_hidden_states,
387
+ return_dict=return_dict,
388
+ )
389
+
390
+ last_hidden_state = encoder_outputs[0] if not return_dict else encoder_outputs.last_hidden_state
391
+
392
+ stats = self.project(last_hidden_state.transpose(1, 2)).transpose(1, 2) * padding_mask
393
+ prior_means, prior_log_variances = torch.split(stats, self.config.flow_size, dim=2)
394
+
395
+ if not return_dict:
396
+ outputs = (last_hidden_state, prior_means, prior_log_variances) + encoder_outputs[1:]
397
+ return outputs
398
+
399
+ return VitsTextEncoderOutput(
400
+ last_hidden_state=last_hidden_state,
401
+ prior_means=prior_means,
402
+ prior_log_variances=prior_log_variances,
403
+ hidden_states=encoder_outputs.hidden_states,
404
+ attentions=encoder_outputs.attentions,
405
+ )
406
+
407
+ #.............................................................................................
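Illustrative only (not part of this commit): the text encoder above can be exercised with dummy token ids as below; the import path and the `[batch, tokens, 1]` padding-mask layout are assumptions read off the code in this diff.

import torch
from VitsModelSplit.vits_config import VitsConfig
from VitsModelSplit.encoder import VitsTextEncoder

config = VitsConfig()
text_encoder = VitsTextEncoder(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 12))      # [batch, tokens]
padding_mask = torch.ones(1, 12, 1)                           # [batch, tokens, 1]

with torch.no_grad():
    outputs = text_encoder(input_ids=input_ids, padding_mask=padding_mask)

print(outputs.last_hidden_state.shape)   # [1, 12, hidden_size]
print(outputs.prior_means.shape)         # [1, 12, flow_size]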
VitsModelSplit/feature_extraction.py ADDED
@@ -0,0 +1,280 @@
1
+ """
2
+ Feature extractor class for Vits
3
+ """
4
+ import copy
5
+ from typing import Any, Dict, List, Optional, Tuple, Union
6
+
7
+ import numpy as np
8
+
9
+ from transformers import is_torch_available
10
+ from transformers.audio_utils import mel_filter_bank
11
+ from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
12
+ from transformers.feature_extraction_utils import BatchFeature
13
+ from transformers.utils import TensorType, logging
14
+
15
+
16
+ MAX_WAV_VALUE = 32768.0
17
+
18
+ if is_torch_available():
19
+ import torch
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ class VitsFeatureExtractor(SequenceFeatureExtractor):
25
+ r"""
26
+ Constructs a Vits feature extractor.
27
+
28
+ This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
29
+ most of the main methods. Users should refer to this superclass for more information regarding those methods.
30
+
31
+ This class extracts `Short Time Fourier Transform` from raw speech using a custom numpy implementation which should
32
+ match pytorch's `torch.stft`.
33
+
34
+ Args:
35
+ feature_size (`int`, defaults to 80):
36
+ The feature dimension of the extracted features.
37
+ sampling_rate (`int`, defaults to 22050):
38
+ The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
39
+ hop_length (`int`, defaults to 256):
40
+ Length of the overlaping windows for the STFT used to obtain the Mel Frequency coefficients.
41
+ n_fft (`int`, defaults to 1024):
42
+ Size of the Fourier transform.
43
+ padding_value (`float`, *optional*, defaults to 0.0):
44
+ Padding value used to pad the audio. Should correspond to silences.
45
+ return_attention_mask (`bool`, *optional*, defaults to `False`):
46
+ Whether to return the attention mask.
47
+
48
+ [What are attention masks?](../glossary#attention-mask)
49
+
50
+ <Tip>
51
+
52
+ For Vits finetuning, `attention_mask` should always be passed for batched inference, to avoid subtle bugs.
53
+
54
+ </Tip>
55
+
56
+ max_wav_value (`float`, defaults to 32768.0):
57
+ Maximum wav value. Used to normalize the input waveforms if `do_normalize=True` in the forward pass of this
58
+ feature extractor.
59
+ """
60
+
61
+ model_input_names = ["input_features"]
62
+
63
+ def __init__(
64
+ self,
65
+ feature_size=80,
66
+ sampling_rate=16000,
67
+ hop_length=256,
68
+ n_fft=1024,
69
+ padding_value=0.0,
70
+ return_attention_mask=False, # pad inputs to max length with silence token (zero) and no attention mask,
71
+ max_wav_value=32768.0,
72
+ **kwargs,
73
+ ):
74
+ super().__init__(
75
+ feature_size=feature_size,
76
+ sampling_rate=sampling_rate,
77
+ padding_value=padding_value,
78
+ return_attention_mask=return_attention_mask,
79
+ **kwargs,
80
+ )
81
+ self.n_fft = n_fft
82
+ self.hop_length = hop_length
83
+ self.sampling_rate = sampling_rate
84
+ self.mel_filters = mel_filter_bank(
85
+ num_frequency_bins=1 + n_fft // 2,
86
+ num_mel_filters=feature_size,
87
+ min_frequency=0.0,
88
+ max_frequency=sampling_rate // 2,
89
+ sampling_rate=sampling_rate,
90
+ norm="slaney",
91
+ mel_scale="slaney",
92
+ )
93
+ self.max_wav_value = max_wav_value
94
+
95
+ def _torch_extract_fbank_features(self, waveform: np.array) -> Tuple[torch.Tensor]:
96
+ """
97
+ Compute the log-mel spectrogram of the provided audio using the PyTorch STFT implementation.
98
+ """
99
+ if len(waveform.shape) == 1:
100
+ waveform = waveform.unsqueeze(0)
101
+
102
+ waveform = torch.nn.functional.pad(
103
+ waveform,
104
+ (int((self.n_fft - self.hop_length) / 2), int((self.n_fft - self.hop_length) / 2)),
105
+ mode="reflect",
106
+ )
107
+
108
+ window = torch.hann_window(self.n_fft).to(waveform.device)
109
+ stft = torch.stft(
110
+ waveform,
111
+ self.n_fft,
112
+ hop_length=self.hop_length,
113
+ win_length=self.n_fft,
114
+ window=window,
115
+ center=False,
116
+ pad_mode="reflect",
117
+ normalized=False,
118
+ onesided=True,
119
+ return_complex=False,
120
+ )
121
+ magnitudes = torch.sqrt(stft.pow(2).sum(-1) + 1e-6)
122
+
123
+ mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32).to(waveform.device)
124
+ mel_spec = mel_filters.T @ magnitudes
125
+
126
+ log_spec = torch.clamp(mel_spec, min=1e-5).log()
127
+ return magnitudes, log_spec
128
+
129
+ def __call__(
130
+ self,
131
+ raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
132
+ truncation: bool = False,
133
+ pad_to_multiple_of: Optional[int] = None,
134
+ return_tensors: Optional[Union[str, TensorType]] = None,
135
+ return_attention_mask: Optional[bool] = True,
136
+ padding: Optional[str] = True,
137
+ max_length: Optional[int] = None,
138
+ sampling_rate: Optional[int] = None,
139
+ do_normalize: Optional[bool] = None,
140
+ **kwargs,
141
+ ) -> BatchFeature:
142
+ """
143
+ Main method to featurize and prepare for the model one or several sequence(s).
144
+
145
+ Args:
146
+ raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
147
+ The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
148
+ values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
149
+ stereo, i.e. single float per timestep.
150
+ truncation (`bool`, *optional*, default to `False`):
151
+ Activates truncation to cut input sequences longer than *max_length* to *max_length*.
152
+ pad_to_multiple_of (`int`, *optional*, defaults to None):
153
+ If set will pad the sequence to a multiple of the provided value.
154
+
155
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
156
+ `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
157
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
158
+ If set, will return tensors instead of list of python integers. Acceptable values are:
159
+
160
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
161
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
162
+ - `'np'`: Return Numpy `np.ndarray` objects.
163
+ return_attention_mask (`bool`, *optional*, defaults to `True`):
164
+ Whether to return the attention mask. If left to the default, will return the attention mask according
165
+ to the specific feature_extractor's default.
166
+
167
+ [What are attention masks?](../glossary#attention-mask)
168
+
169
+ <Tip>
170
+
171
+ For Vits finetuning, `attention_mask` should always be passed for batched inference, to avoid subtle
172
+ bugs.
173
+
174
+ </Tip>
175
+
176
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
177
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
178
+ index) among:
179
+
180
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
181
+ sequence if provided).
182
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
183
+ acceptable input length for the model if that argument is not provided.
184
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
185
+ lengths).
186
+ max_length (`int`, *optional*):
187
+ Maximum length of the returned list and optionally padding length (see above).
188
+ sampling_rate (`int`, *optional*):
189
+ The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
190
+ `sampling_rate` at the forward call to prevent silent errors and allow automatic speech recognition
191
+ pipeline.
192
+ do_normalize (`bool`, *optional*):
193
+ Whether or not to divide the input waveform by `self.max_wav_value`.
194
+ """
195
+
196
+ if sampling_rate is not None:
197
+ if sampling_rate != self.sampling_rate:
198
+ raise ValueError(
199
+ f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a"
200
+ f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input"
201
+ f" was sampled with {self.sampling_rate} and not {sampling_rate}."
202
+ )
203
+ else:
204
+ logger.warning(
205
+ "It is strongly recommended to pass the `sampling_rate` argument to this function. "
206
+ "Failing to do so can result in silent errors that might be hard to debug."
207
+ )
208
+
209
+ is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
210
+ if is_batched_numpy and len(raw_speech.shape) > 2:
211
+ raise ValueError(f"Only mono-channel audio is supported for input to {self}")
212
+ is_batched = is_batched_numpy or (
213
+ isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
214
+ )
215
+
216
+ if is_batched:
217
+ raw_speech = [np.asarray([speech], dtype=np.float32).T for speech in raw_speech]
218
+ elif not is_batched and not isinstance(raw_speech, np.ndarray):
219
+ raw_speech = np.asarray(raw_speech, dtype=np.float32)
220
+ elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
221
+ raw_speech = raw_speech.astype(np.float32)
222
+
223
+ # always return batch
224
+ if not is_batched:
225
+ raw_speech = [np.asarray([raw_speech]).T]
226
+
227
+ if self.max_wav_value is not None and do_normalize:
228
+ raw_speech = [
229
+ speech if self.max_wav_value is None else speech / self.max_wav_value for speech in raw_speech
230
+ ]
231
+
232
+ batched_speech = BatchFeature({"input_features": raw_speech})
233
+
234
+ # convert into correct format for padding
235
+ padded_inputs = self.pad(
236
+ batched_speech,
237
+ padding=padding,
238
+ max_length=max_length,
239
+ truncation=truncation,
240
+ pad_to_multiple_of=pad_to_multiple_of,
241
+ return_attention_mask=return_attention_mask or do_normalize,
242
+ return_tensors="pt",
243
+ )
244
+
245
+ # make sure list is in array format
246
+ if isinstance(padded_inputs.get("input_features"),list):
247
+ input_features = torch.tensor(padded_inputs.get("input_features")).transpose(1, 2).transpose(0, 1)
248
+ else:
249
+ input_features = padded_inputs.get("input_features").clone().detach().transpose(1, 2).transpose(0, 1)
250
+
251
+
252
+ input_features = self._torch_extract_fbank_features(input_features[0])
253
+
254
+ mel_scaled_input_features = input_features[1]
255
+ input_features = input_features[0]
256
+
257
+ padded_inputs["input_features"] = input_features
258
+ padded_inputs["mel_scaled_input_features"] = mel_scaled_input_features
259
+
260
+ if return_attention_mask:
261
+ # rescale from sample (48000) to feature (3000)
262
+ padded_inputs["attention_mask"] = padded_inputs["attention_mask"][:, :: self.hop_length]
263
+
264
+ if return_tensors is not None:
265
+ padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
266
+
267
+ return padded_inputs
268
+
269
+ def to_dict(self) -> Dict[str, Any]:
270
+ """
271
+ Serializes this instance to a Python dictionary.
272
+
273
+ Returns:
274
+ `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
275
+ """
276
+ output = copy.deepcopy(self.__dict__)
277
+ output["feature_extractor_type"] = self.__class__.__name__
278
+ if "mel_filters" in output:
279
+ del output["mel_filters"]
280
+ return output
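A hypothetical usage sketch for the feature extractor (not part of this commit), assuming 16 kHz mono audio, a recent PyTorch/transformers installation, and the import path below:

import numpy as np
from VitsModelSplit.feature_extraction import VitsFeatureExtractor   # assumed import path

feature_extractor = VitsFeatureExtractor()                    # sampling_rate=16000 by default
waveform = np.random.randn(16000).astype(np.float32)          # 1 second of dummy audio

features = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")
print(features["input_features"].shape)              # linear spectrogram magnitudes [1, 513, frames]
print(features["mel_scaled_input_features"].shape)   # log-mel spectrogram [1, 80, frames]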
VitsModelSplit/finetune_config_ara.json ADDED
@@ -0,0 +1,55 @@
1
+ {
2
+ "project_name": "vits_ara",
3
+ "push_to_hub": false,
4
+ "hub_model_id": "ara_tts_finetuning/ara_tts_finetuning",
5
+ "overwrite_output_dir": true,
6
+ "output_dir": "./output",
7
+
8
+ "dataset_name": "./dataset/",
9
+ "dataset_config_name": "welsh_female",
10
+ "audio_column_name": "audio",
11
+ "text_column_name":"text",
12
+ "train_split_name": "train",
13
+ "eval_split_name": "eval",
14
+
15
+ "override_speaker_embeddings": false,
16
+ "filter_on_speaker_id": 5223,
17
+
18
+
19
+ "max_duration_in_seconds": 20,
20
+ "min_duration_in_seconds": 1.0,
21
+ "max_tokens_length": 500,
22
+
23
+ "model_name_or_path": "facebook/mms-tts-ara",
24
+
25
+ "full_generation_sample_text": "اوريه و اخليه يعرف هو حاط نفسه في مواجهة مع مين",
26
+ "preprocessing_num_workers": 4,
27
+
28
+ "do_train": true,
29
+ "num_train_epochs": 300,
30
+ "gradient_accumulation_steps": 1,
31
+ "per_device_train_batch_size": 10,
32
+ "learning_rate": 2e-5,
33
+ "adam_beta1": 0.8,
34
+ "adam_beta2": 0.99,
35
+ "warmup_ratio": 0.01,
36
+ "d_learning_rate": 2e-5,
37
+ "d_adam_beta1": 0.7,
38
+ "d_adam_beta2": 0.99,
39
+
40
+ "do_eval": true,
41
+ "eval_steps": 10,
42
+ "per_device_eval_batch_size": 10,
43
+ "max_eval_samples": 2,
44
+ "do_step_schedule_per_epoch": true,
45
+
46
+ "weight_disc": 3,
47
+ "weight_fmaps": 1,
48
+ "weight_gen": 1,
49
+ "weight_kl": 1.5,
50
+ "weight_duration": 1,
51
+ "weight_mel": 35,
52
+
53
+ "fp16": false,
54
+ "seed": 456
55
+ }
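For illustration (not part of this commit), the finetuning configuration can be loaded as plain JSON and inspected before training, e.g.:

import json

with open("VitsModelSplit/finetune_config_ara.json", encoding="utf-8") as f:
    finetune_config = json.load(f)

print(finetune_config["model_name_or_path"])   # facebook/mms-tts-ara
print(finetune_config["learning_rate"])        # 2e-05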
VitsModelSplit/flow.py ADDED
@@ -0,0 +1,190 @@
1
+
2
+ import torch
3
+ from torch import nn
4
+ from typing import Optional
5
+ from .vits_config import VitsConfig
6
+ #.............................................
7
+
8
+ @torch.jit.script
9
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, num_channels):
10
+ in_act = input_a + input_b
11
+ t_act = torch.tanh(in_act[:, :num_channels, :])
12
+ s_act = torch.sigmoid(in_act[:, num_channels:, :])
13
+ acts = t_act * s_act
14
+ return acts
15
+
16
+
17
+
18
+ #.............................................
19
+
20
+ class VitsWaveNet(torch.nn.Module):
21
+ def __init__(self, config: VitsConfig, num_layers: int):
22
+ super().__init__()
23
+ self.hidden_size = config.hidden_size
24
+ self.num_layers = num_layers
25
+ self.speaker_embedding_size = config.speaker_embedding_size
26
+
27
+ self.in_layers = torch.nn.ModuleList()
28
+ self.res_skip_layers = torch.nn.ModuleList()
29
+ self.dropout = nn.Dropout(config.wavenet_dropout)
30
+
31
+ if hasattr(nn.utils.parametrizations, "weight_norm"):
32
+ weight_norm = nn.utils.parametrizations.weight_norm
33
+ else:
34
+ weight_norm = nn.utils.weight_norm
35
+
36
+ if config.speaker_embedding_size != 0:
37
+ cond_layer = torch.nn.Conv1d(config.speaker_embedding_size, 2 * config.hidden_size * num_layers, 1)
38
+ self.cond_layer = weight_norm(cond_layer, name="weight")
39
+
40
+ for i in range(num_layers):
41
+ dilation = config.wavenet_dilation_rate**i
42
+ padding = (config.wavenet_kernel_size * dilation - dilation) // 2
43
+ in_layer = torch.nn.Conv1d(
44
+ in_channels=config.hidden_size,
45
+ out_channels=2 * config.hidden_size,
46
+ kernel_size=config.wavenet_kernel_size,
47
+ dilation=dilation,
48
+ padding=padding,
49
+ )
50
+ in_layer = weight_norm(in_layer, name="weight")
51
+ self.in_layers.append(in_layer)
52
+
53
+ # last one is not necessary
54
+ if i < num_layers - 1:
55
+ res_skip_channels = 2 * config.hidden_size
56
+ else:
57
+ res_skip_channels = config.hidden_size
58
+
59
+ res_skip_layer = torch.nn.Conv1d(config.hidden_size, res_skip_channels, 1)
60
+ res_skip_layer = weight_norm(res_skip_layer, name="weight")
61
+ self.res_skip_layers.append(res_skip_layer)
62
+
63
+ def forward(self, inputs, padding_mask, global_conditioning=None):
64
+ outputs = torch.zeros_like(inputs)
65
+ num_channels_tensor = torch.IntTensor([self.hidden_size])
66
+
67
+ if global_conditioning is not None:
68
+ global_conditioning = self.cond_layer(global_conditioning)
69
+
70
+ for i in range(self.num_layers):
71
+ hidden_states = self.in_layers[i](inputs)
72
+
73
+ if global_conditioning is not None:
74
+ cond_offset = i * 2 * self.hidden_size
75
+ global_states = global_conditioning[:, cond_offset : cond_offset + 2 * self.hidden_size, :]
76
+ else:
77
+ global_states = torch.zeros_like(hidden_states)
78
+
79
+ acts = fused_add_tanh_sigmoid_multiply(hidden_states, global_states, num_channels_tensor[0])
80
+ acts = self.dropout(acts)
81
+
82
+ res_skip_acts = self.res_skip_layers[i](acts)
83
+ if i < self.num_layers - 1:
84
+ res_acts = res_skip_acts[:, : self.hidden_size, :]
85
+ inputs = (inputs + res_acts) * padding_mask
86
+ outputs = outputs + res_skip_acts[:, self.hidden_size :, :]
87
+ else:
88
+ outputs = outputs + res_skip_acts
89
+
90
+ return outputs * padding_mask
91
+
92
+ def remove_weight_norm(self):
93
+ if self.speaker_embedding_size != 0:
94
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
95
+ for layer in self.in_layers:
96
+ torch.nn.utils.remove_weight_norm(layer)
97
+ for layer in self.res_skip_layers:
98
+ torch.nn.utils.remove_weight_norm(layer)
99
+
100
+ def apply_weight_norm(self):
101
+ if hasattr(nn.utils.parametrizations, "weight_norm"):
102
+ weight_norm = nn.utils.parametrizations.weight_norm
103
+ else:
104
+ weight_norm = nn.utils.weight_norm
105
+
106
+ if self.speaker_embedding_size != 0:
107
+ weight_norm(self.cond_layer)
108
+ for layer in self.in_layers:
109
+ weight_norm(layer)
110
+ for layer in self.res_skip_layers:
111
+ weight_norm(layer)
112
+
113
+
114
+ #.............................................................................................
115
+
116
+ class VitsResidualCouplingLayer(nn.Module):
117
+ def __init__(self, config: VitsConfig):
118
+ super().__init__()
119
+ self.half_channels = config.flow_size // 2
120
+
121
+ self.conv_pre = nn.Conv1d(self.half_channels, config.hidden_size, 1)
122
+ self.wavenet = VitsWaveNet(config, num_layers=config.prior_encoder_num_wavenet_layers)
123
+ self.conv_post = nn.Conv1d(config.hidden_size, self.half_channels, 1)
124
+
125
+ def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
126
+ first_half, second_half = torch.split(inputs, [self.half_channels] * 2, dim=1)
127
+ hidden_states = self.conv_pre(first_half) * padding_mask
128
+ hidden_states = self.wavenet(hidden_states, padding_mask, global_conditioning)
129
+ mean = self.conv_post(hidden_states) * padding_mask
130
+ log_stddev = torch.zeros_like(mean)
131
+
132
+ if not reverse:
133
+ second_half = mean + second_half * torch.exp(log_stddev) * padding_mask
134
+ outputs = torch.cat([first_half, second_half], dim=1)
135
+ log_determinant = torch.sum(log_stddev, [1, 2])
136
+ return outputs, log_determinant
137
+ else:
138
+ second_half = (second_half - mean) * torch.exp(-log_stddev) * padding_mask
139
+ outputs = torch.cat([first_half, second_half], dim=1)
140
+ return outputs, None
141
+
142
+ def apply_weight_norm(self):
143
+ nn.utils.weight_norm(self.conv_pre)
144
+ self.wavenet.apply_weight_norm()
145
+ nn.utils.weight_norm(self.conv_post)
146
+
147
+ def remove_weight_norm(self):
148
+ nn.utils.remove_weight_norm(self.conv_pre)
149
+ self.wavenet.remove_weight_norm()
150
+ nn.utils.remove_weight_norm(self.conv_post)
151
+
152
+
153
+
154
+ #.............................................................................................
155
+
156
+ class VitsResidualCouplingBlock(nn.Module):
157
+ def __init__(self, config: VitsConfig):
158
+ super().__init__()
159
+ self.flows = nn.ModuleList()
160
+ for _ in range(config.prior_encoder_num_flows):
161
+ self.flows.append(VitsResidualCouplingLayer(config))
162
+
163
+ def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
164
+ if not reverse:
165
+ for flow in self.flows:
166
+ inputs, _ = flow(inputs, padding_mask, global_conditioning)
167
+ inputs = torch.flip(inputs, [1])
168
+ else:
169
+ for flow in reversed(self.flows):
170
+ inputs = torch.flip(inputs, [1])
171
+ inputs, _ = flow(inputs, padding_mask, global_conditioning, reverse=True)
172
+ return inputs
173
+
174
+ def apply_weight_norm(self):
175
+ for flow in self.flows:
176
+ flow.apply_weight_norm()
177
+
178
+ def remove_weight_norm(self):
179
+ for flow in self.flows:
180
+ flow.remove_weight_norm()
181
+
182
+ def resize_speaker_embeddings(self, speaker_embedding_size: Optional[int] = None):
183
+ for flow in self.flows:
184
+ flow.wavenet.speaker_embedding_size = speaker_embedding_size
185
+ hidden_size = flow.wavenet.hidden_size
186
+ num_layers = flow.wavenet.num_layers
187
+
188
+ cond_layer = torch.nn.Conv1d(speaker_embedding_size, 2 * hidden_size * num_layers, 1)
189
+ flow.wavenet.cond_layer = nn.utils.weight_norm(cond_layer, name="weight")
190
+
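A quick sanity check (illustrative only, not part of this commit): because each coupling layer uses a zero `log_stddev`, running the residual coupling block forward and then in reverse should recover the original latents up to floating-point error. Import paths are assumptions.

import torch
from VitsModelSplit.vits_config import VitsConfig
from VitsModelSplit.flow import VitsResidualCouplingBlock

config = VitsConfig()
flow = VitsResidualCouplingBlock(config).eval()

latents = torch.randn(1, config.flow_size, 20)
padding_mask = torch.ones(1, 1, 20)

with torch.no_grad():
    transformed = flow(latents, padding_mask)                    # forward direction
    recovered = flow(transformed, padding_mask, reverse=True)    # inverse direction

print(torch.allclose(latents, recovered, atol=1e-4))             # True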
VitsModelSplit/mk ADDED
@@ -0,0 +1 @@
1
+
VitsModelSplit/monotonic_align/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ import numpy as np
2
+ import torch
3
+ from .monotonic_align.core import maximum_path_c
4
+
5
+
6
+ def maximum_path(neg_cent, mask):
7
+ """ Cython optimized version.
8
+ neg_cent: [b, t_t, t_s]
9
+ mask: [b, t_t, t_s]
10
+ """
11
+ device = neg_cent.device
12
+ dtype = neg_cent.dtype
13
+ neg_cent = neg_cent.data.cpu().numpy().astype(np.float32)
14
+ path = np.zeros(neg_cent.shape, dtype=np.int32)
15
+
16
+ t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32)
17
+ t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32)
18
+ maximum_path_c(path, neg_cent, t_t_max, t_s_max)
19
+ return torch.from_numpy(path).to(device=device, dtype=dtype)
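Hypothetical example (not part of this commit): once the Cython extension is compiled (see `setup.py` below), `maximum_path` returns a binary, monotonically non-decreasing alignment with one active entry per row, following the shapes in the docstring above.

import torch
from VitsModelSplit.monotonic_align import maximum_path   # assumed import path

neg_cent = torch.randn(1, 5, 3)       # [b, t_t, t_s], as documented above
mask = torch.ones(1, 5, 3)

path = maximum_path(neg_cent, mask)   # same shape, entries in {0, 1}
print(path[0])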
VitsModelSplit/monotonic_align/core.c ADDED
The diff for this file is too large to render. See raw diff
 
VitsModelSplit/monotonic_align/core.pyx ADDED
@@ -0,0 +1,42 @@
1
+ cimport cython
2
+ from cython.parallel import prange
3
+
4
+
5
+ @cython.boundscheck(False)
6
+ @cython.wraparound(False)
7
+ cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil:
8
+ cdef int x
9
+ cdef int y
10
+ cdef float v_prev
11
+ cdef float v_cur
12
+ cdef float tmp
13
+ cdef int index = t_x - 1
14
+
15
+ for y in range(t_y):
16
+ for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
17
+ if x == y:
18
+ v_cur = max_neg_val
19
+ else:
20
+ v_cur = value[y-1, x]
21
+ if x == 0:
22
+ if y == 0:
23
+ v_prev = 0.
24
+ else:
25
+ v_prev = max_neg_val
26
+ else:
27
+ v_prev = value[y-1, x-1]
28
+ value[y, x] += max(v_prev, v_cur)
29
+
30
+ for y in range(t_y - 1, -1, -1):
31
+ path[y, index] = 1
32
+ if index != 0 and (index == y or value[y-1, index] < value[y-1, index-1]):
33
+ index = index - 1
34
+
35
+
36
+ @cython.boundscheck(False)
37
+ @cython.wraparound(False)
38
+ cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_ys, int[::1] t_xs) nogil:
39
+ cdef int b = paths.shape[0]
40
+ cdef int i
41
+ for i in prange(b, nogil=True):
42
+ maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i])
VitsModelSplit/monotonic_align/data ADDED
@@ -0,0 +1 @@
1
+
VitsModelSplit/monotonic_align/setup.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from distutils.core import setup
2
+ from Cython.Build import cythonize
3
+ import numpy
4
+
5
+ setup(
6
+ name = 'monotonic_align',
7
+ ext_modules = cythonize("core.pyx"),
8
+ include_dirs=[numpy.get_include()]
9
+ )
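A hypothetical build helper (not part of this commit): the `core` extension used by `monotonic_align/__init__.py` has to be compiled in place, which is what the standard `build_ext --inplace` invocation of this `setup.py` does.

import subprocess
import sys

# equivalent to: cd VitsModelSplit/monotonic_align && python setup.py build_ext --inplace
subprocess.run(
    [sys.executable, "setup.py", "build_ext", "--inplace"],
    cwd="VitsModelSplit/monotonic_align",
    check=True,
)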
VitsModelSplit/plot.py ADDED
@@ -0,0 +1,63 @@
1
+ import logging
2
+ import matplotlib
3
+
4
+
5
+ matplotlib.use("Agg")
6
+
7
+ MATPLOTLIB_FLAG = False
8
+
9
+
10
+ def plot_spectrogram_to_numpy(spectrogram):
11
+ global MATPLOTLIB_FLAG
12
+ if not MATPLOTLIB_FLAG:
13
+ import matplotlib
14
+
15
+ matplotlib.use("Agg")
16
+ MATPLOTLIB_FLAG = True
17
+ mpl_logger = logging.getLogger("matplotlib")
18
+ mpl_logger.setLevel(logging.WARNING)
19
+ import matplotlib.pylab as plt
20
+ import numpy as np
21
+
22
+ fig, ax = plt.subplots(figsize=(10, 2))
23
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
24
+
25
+ plt.colorbar(im, ax=ax)
26
+ plt.xlabel("Frames")
27
+ plt.ylabel("Channels")
28
+ plt.tight_layout()
29
+ fig.canvas.draw()
30
+
31
+ data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
32
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
33
+ plt.close()
34
+ return data
35
+
36
+
37
+ def plot_alignment_to_numpy(alignment, info=None):
38
+ global MATPLOTLIB_FLAG
39
+ if not MATPLOTLIB_FLAG:
40
+ import matplotlib
41
+
42
+ matplotlib.use("Agg")
43
+ MATPLOTLIB_FLAG = True
44
+ mpl_logger = logging.getLogger("matplotlib")
45
+ mpl_logger.setLevel(logging.WARNING)
46
+ import matplotlib.pylab as plt
47
+ import numpy as np
48
+
49
+ fig, ax = plt.subplots(figsize=(6, 4))
50
+ im = ax.imshow(alignment.transpose(), aspect="auto", origin="lower", interpolation="none")
51
+ fig.colorbar(im, ax=ax)
52
+ xlabel = "Decoder timestep"
53
+ if info is not None:
54
+ xlabel += "\n\n" + info
55
+ plt.xlabel(xlabel)
56
+ plt.ylabel("Encoder timestep")
57
+ plt.tight_layout()
58
+
59
+ fig.canvas.draw()
60
+ data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
61
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
62
+ plt.close()
63
+ return data
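Illustrative usage (not part of this commit): both helpers return an RGB `uint8` array that can be logged to wandb or TensorBoard during finetuning, e.g.

import numpy as np
from VitsModelSplit.plot import plot_spectrogram_to_numpy   # assumed import path

mel = np.random.rand(80, 200)              # [mel_bins, frames], dummy data
image = plot_spectrogram_to_numpy(mel)     # [height, width, 3] uint8 array
print(image.shape, image.dtype)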
VitsModelSplit/posterior_encoder.py ADDED
@@ -0,0 +1,50 @@
1
+ import math
2
+ from typing import Optional
3
+ import torch
4
+ from torch import nn
5
+ from .vits_config import VitsConfig
6
+ from .flow import VitsWaveNet
7
+
8
+ #.............................................
9
+
10
+
11
+
12
+ class VitsPosteriorEncoder(nn.Module):
13
+ def __init__(self, config: VitsConfig):
14
+ super().__init__()
15
+ self.out_channels = config.flow_size
16
+
17
+ self.conv_pre = nn.Conv1d(config.spectrogram_bins, config.hidden_size, 1)
18
+ self.wavenet = VitsWaveNet(config, num_layers=config.posterior_encoder_num_wavenet_layers)
19
+ self.conv_proj = nn.Conv1d(config.hidden_size, self.out_channels * 2, 1)
20
+
21
+ def forward(self, inputs, padding_mask, global_conditioning=None):
22
+ inputs = self.conv_pre(inputs) * padding_mask
23
+ inputs = self.wavenet(inputs, padding_mask, global_conditioning)
24
+ stats = self.conv_proj(inputs) * padding_mask
25
+ mean, log_stddev = torch.split(stats, self.out_channels, dim=1)
26
+ sampled = (mean + torch.randn_like(mean) * torch.exp(log_stddev)) * padding_mask
27
+ return sampled, mean, log_stddev
28
+
29
+ def apply_weight_norm(self):
30
+ self.wavenet.apply_weight_norm()
31
+
32
+ def remove_weight_norm(self):
33
+ self.wavenet.remove_weight_norm()
34
+
35
+ def resize_speaker_embeddings(self, speaker_embedding_size: Optional[int] = None):
36
+ self.wavenet.speaker_embedding_size = speaker_embedding_size
37
+ hidden_size = self.wavenet.hidden_size
38
+ num_layers = self.wavenet.num_layers
39
+
40
+ cond_layer = torch.nn.Conv1d(speaker_embedding_size, 2 * hidden_size * num_layers, 1)
41
+ self.wavenet.cond_layer = nn.utils.weight_norm(cond_layer, name="weight")
42
+ nn.init.kaiming_normal_(self.wavenet.cond_layer.weight)
43
+ if self.wavenet.cond_layer.bias is not None:
44
+ k = math.sqrt(
45
+ self.wavenet.cond_layer.groups
46
+ / (self.wavenet.cond_layer.in_channels * self.wavenet.cond_layer.kernel_size[0])
47
+ )
48
+ nn.init.uniform_(self.wavenet.cond_layer.bias, a=-k, b=k)
49
+
50
+ #.............................................................................................
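Shape sketch (illustrative only, not part of this commit): the posterior encoder maps a linear spectrogram to flow-sized latents by sampling `z = mean + eps * exp(log_stddev)` (the reparameterization trick). Import paths are assumptions; the default config uses `spectrogram_bins=513` and `flow_size=192`.

import torch
from VitsModelSplit.vits_config import VitsConfig
from VitsModelSplit.posterior_encoder import VitsPosteriorEncoder

config = VitsConfig()
posterior = VitsPosteriorEncoder(config).eval()

spectrogram = torch.randn(1, config.spectrogram_bins, 40)   # [batch, bins, frames]
padding_mask = torch.ones(1, 1, 40)

with torch.no_grad():
    latents, mean, log_stddev = posterior(spectrogram, padding_mask)
print(latents.shape)                                        # [1, 192, 40]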
VitsModelSplit/requirements.txt ADDED
@@ -0,0 +1,10 @@
1
+ Cython==0.29.21
2
+ librosa==0.8.0
3
+ matplotlib==3.3.1
4
+ numpy==1.18.5
5
+ phonemizer==2.2.1
6
+ scipy==1.5.2
7
+ tensorboard==2.3.0
8
+ torch==1.6.0
9
+ torchvision==0.7.0
10
+ Unidecode==1.1.1
VitsModelSplit/vits_config.py ADDED
@@ -0,0 +1,162 @@
1
+ import math
2
+ from transformers.configuration_utils import PretrainedConfig
3
+ from transformers.modeling_utils import PreTrainedModel
4
+ from torch import nn
5
+
6
+ #.............................................
7
+
8
+
9
+
10
+ class VitsConfig(PretrainedConfig):
11
+ model_type = "vits"
12
+
13
+ def __init__(
14
+ self,
15
+ vocab_size=38,
16
+ hidden_size=192,
17
+ num_hidden_layers=6,
18
+ num_attention_heads=2,
19
+ window_size=4,
20
+ use_bias=True,
21
+ ffn_dim=768,
22
+ layerdrop=0.1,
23
+ ffn_kernel_size=3,
24
+ flow_size=192,
25
+ spectrogram_bins=513,
26
+ hidden_act="relu",
27
+ hidden_dropout=0.1,
28
+ attention_dropout=0.1,
29
+ activation_dropout=0.1,
30
+ initializer_range=0.02,
31
+ layer_norm_eps=1e-5,
32
+ use_stochastic_duration_prediction=True,
33
+ num_speakers=1,
34
+ speaker_embedding_size=0,
35
+ upsample_initial_channel=512,
36
+ upsample_rates=[8, 8, 2, 2],
37
+ upsample_kernel_sizes=[16, 16, 4, 4],
38
+ resblock_kernel_sizes=[3, 7, 11],
39
+ resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
40
+ leaky_relu_slope=0.1,
41
+ depth_separable_channels=2,
42
+ depth_separable_num_layers=3,
43
+ duration_predictor_flow_bins=10,
44
+ duration_predictor_tail_bound=5.0,
45
+ duration_predictor_kernel_size=3,
46
+ duration_predictor_dropout=0.5,
47
+ duration_predictor_num_flows=4,
48
+ duration_predictor_filter_channels=256,
49
+ prior_encoder_num_flows=4,
50
+ prior_encoder_num_wavenet_layers=4,
51
+ posterior_encoder_num_wavenet_layers=16,
52
+ wavenet_kernel_size=5,
53
+ wavenet_dilation_rate=1,
54
+ wavenet_dropout=0.0,
55
+ speaking_rate=1.0,
56
+ noise_scale=0.667,
57
+ noise_scale_duration=0.8,
58
+ sampling_rate=16_000,
59
+ discriminator_kernel_size=5,
60
+ discriminator_stride=3,
61
+ discriminator_periods=[2, 3, 5, 7, 11],
62
+ discriminator_period_channels=[1, 32, 128, 512, 1024],
63
+ discriminator_scale_channels=[1, 16, 64, 256, 1024],
64
+ segment_size=8192,
65
+ hop_length=256,
66
+ **kwargs,
67
+ ):
68
+ self.vocab_size = vocab_size
69
+ self.hidden_size = hidden_size
70
+ self.num_hidden_layers = num_hidden_layers
71
+ self.num_attention_heads = num_attention_heads
72
+ self.window_size = window_size
73
+ self.use_bias = use_bias
74
+ self.ffn_dim = ffn_dim
75
+ self.layerdrop = layerdrop
76
+ self.ffn_kernel_size = ffn_kernel_size
77
+ self.flow_size = flow_size
78
+ self.spectrogram_bins = spectrogram_bins
79
+ self.hidden_act = hidden_act
80
+ self.hidden_dropout = hidden_dropout
81
+ self.attention_dropout = attention_dropout
82
+ self.activation_dropout = activation_dropout
83
+ self.initializer_range = initializer_range
84
+ self.layer_norm_eps = layer_norm_eps
85
+ self.use_stochastic_duration_prediction = use_stochastic_duration_prediction
86
+ self.num_speakers = num_speakers
87
+ self.speaker_embedding_size = speaker_embedding_size
88
+ self.upsample_initial_channel = upsample_initial_channel
89
+ self.upsample_rates = upsample_rates
90
+ self.upsample_kernel_sizes = upsample_kernel_sizes
91
+ self.resblock_kernel_sizes = resblock_kernel_sizes
92
+ self.resblock_dilation_sizes = resblock_dilation_sizes
93
+ self.leaky_relu_slope = leaky_relu_slope
94
+ self.depth_separable_channels = depth_separable_channels
95
+ self.depth_separable_num_layers = depth_separable_num_layers
96
+ self.duration_predictor_flow_bins = duration_predictor_flow_bins
97
+ self.duration_predictor_tail_bound = duration_predictor_tail_bound
98
+ self.duration_predictor_kernel_size = duration_predictor_kernel_size
99
+ self.duration_predictor_dropout = duration_predictor_dropout
100
+ self.duration_predictor_num_flows = duration_predictor_num_flows
101
+ self.duration_predictor_filter_channels = duration_predictor_filter_channels
102
+ self.prior_encoder_num_flows = prior_encoder_num_flows
103
+ self.prior_encoder_num_wavenet_layers = prior_encoder_num_wavenet_layers
104
+ self.posterior_encoder_num_wavenet_layers = posterior_encoder_num_wavenet_layers
105
+ self.wavenet_kernel_size = wavenet_kernel_size
106
+ self.wavenet_dilation_rate = wavenet_dilation_rate
107
+ self.wavenet_dropout = wavenet_dropout
108
+ self.speaking_rate = speaking_rate
109
+ self.noise_scale = noise_scale
110
+ self.noise_scale_duration = noise_scale_duration
111
+ self.sampling_rate = sampling_rate
112
+
113
+ # used for training
114
+ self.discriminator_kernel_size = discriminator_kernel_size
115
+ self.discriminator_stride = discriminator_stride
116
+ self.discriminator_periods = discriminator_periods
117
+ self.discriminator_period_channels = discriminator_period_channels
118
+ self.discriminator_scale_channels = discriminator_scale_channels
119
+ self.segment_size = segment_size
120
+ self.hop_length = hop_length
121
+
122
+ if len(upsample_kernel_sizes) != len(upsample_rates):
123
+ raise ValueError(
124
+ f"The length of `upsample_kernel_sizes` ({len(upsample_kernel_sizes)}) must match the length of "
125
+ f"`upsample_rates` ({len(upsample_rates)})"
126
+ )
127
+
128
+ super().__init__(**kwargs)
129
+
130
+ #.............................................................................................
131
+
132
+ class VitsPreTrainedModel(PreTrainedModel):
133
+ """
134
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
135
+ models.
136
+ """
137
+ config_class = VitsConfig
138
+ base_model_prefix = "vits"
139
+ main_input_name = "input_ids"
140
+ supports_gradient_checkpointing = True
141
+
142
+ def _init_weights(self, module):
143
+ """Initialize the weights"""
144
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
145
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
146
+ if module.bias is not None:
147
+ module.bias.data.zero_()
148
+ elif isinstance(module, nn.LayerNorm):
149
+ module.bias.data.zero_()
150
+ module.weight.data.fill_(1.0)
151
+ elif isinstance(module, nn.Conv1d):
152
+ nn.init.kaiming_normal_(module.weight)
153
+ if module.bias is not None:
154
+ k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
155
+ nn.init.uniform_(module.bias, a=-k, b=k)
156
+ elif isinstance(module, nn.Embedding):
157
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
158
+ if module.padding_idx is not None:
159
+ module.weight.data[module.padding_idx].zero_()
160
+
161
+
162
+ #.............................................................................................
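For illustration (not part of this commit), a config can be built directly with keyword overrides; the values below are made up and only show the mechanism (e.g. enabling speaker embeddings for multi-speaker finetuning).

from VitsModelSplit.vits_config import VitsConfig   # assumed import path

config = VitsConfig(
    num_speakers=4,              # hypothetical values, for illustration only
    speaker_embedding_size=256,
    sampling_rate=16_000,
)
print(config.hidden_size, config.flow_size)          # 192 192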
VitsModelSplit/vits_model.py ADDED
@@ -0,0 +1,447 @@
1
+
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+ import math
6
+ from typing import Any, Callable, Optional, Tuple, Union
7
+ from torch.cuda.amp import autocast, GradScaler
8
+
9
+ from .vits_config import VitsConfig,VitsPreTrainedModel
10
+ from .flow import VitsResidualCouplingBlock
11
+ from .duration_predictor import VitsDurationPredictor, VitsStochasticDurationPredictor
12
+ from .encoder import VitsTextEncoder
13
+ from .decoder import VitsHifiGan
14
+ from .posterior_encoder import VitsPosteriorEncoder
15
+ from .discriminator import VitsDiscriminator
16
+ from .vits_output import VitsModelOutput, VitsTrainingOutput
17
+
18
+
19
+ class VitsModel(VitsPreTrainedModel):
20
+
21
+ def __init__(self, config: VitsConfig):
22
+ super().__init__(config)
23
+
24
+ self.config = config
25
+ self.text_encoder = VitsTextEncoder(config)
26
+ self.flow = VitsResidualCouplingBlock(config)
27
+ self.decoder = VitsHifiGan(config)
28
+
29
+
30
+
31
+ if config.use_stochastic_duration_prediction:
32
+ self.duration_predictor = VitsStochasticDurationPredictor(config)
33
+ else:
34
+ self.duration_predictor = VitsDurationPredictor(config)
35
+
36
+ if config.num_speakers > 1:
37
+ self.embed_speaker = nn.Embedding(config.num_speakers, config.speaker_embedding_size)
38
+
39
+ # This is used only for training.
40
+ self.posterior_encoder = VitsPosteriorEncoder(config)
41
+ self.discriminator = VitsDiscriminator(config)
42
+
43
+ # These parameters control the synthesised speech properties
44
+ self.speaking_rate = config.speaking_rate
45
+ self.noise_scale = config.noise_scale
46
+ self.noise_scale_duration = config.noise_scale_duration
47
+ self.segment_size = self.config.segment_size // self.config.hop_length
48
+
49
+ # Initialize weights and apply final processing
50
+ self.post_init()
51
+
52
+
53
+ #....................................
54
+
55
+ def monotonic_align_max_path(self,log_likelihoods, mask):
56
+ # used for training - awfully slow
57
+ # an alternative is proposed in examples/pytorch/text-to-speech/run_vits_finetuning.py
58
+ path = torch.zeros_like(log_likelihoods)
59
+
60
+ text_length_maxs = mask.sum(1)[:, 0]
61
+ latent_length_maxs = mask.sum(2)[:, 0]
62
+
63
+ indexes = latent_length_maxs - 1
64
+
65
+ max_neg_val = -1e9
66
+
67
+ for batch_id in range(len(path)):
68
+ index = int(indexes[batch_id].item())
69
+ text_length_max = int(text_length_maxs[batch_id].item())
70
+ latent_length_max = int(latent_length_maxs[batch_id].item())
71
+
72
+ for y in range(text_length_max):
73
+ for x in range(max(0, latent_length_max + y - text_length_max), min(latent_length_max, y + 1)):
74
+ if x == y:
75
+ v_cur = max_neg_val
76
+ else:
77
+ v_cur = log_likelihoods[batch_id, y - 1, x]
78
+ if x == 0:
79
+ if y == 0:
80
+ v_prev = 0.0
81
+ else:
82
+ v_prev = max_neg_val
83
+ else:
84
+ v_prev = log_likelihoods[batch_id, y - 1, x - 1]
85
+ log_likelihoods[batch_id, y, x] += max(v_prev, v_cur)
86
+
87
+ for y in range(text_length_max - 1, -1, -1):
88
+ path[batch_id, y, index] = 1
89
+ if index != 0 and (
90
+ index == y or log_likelihoods[batch_id, y - 1, index] < log_likelihoods[batch_id, y - 1, index - 1]
91
+ ):
92
+ index = index - 1
93
+ return path
94
+
95
+ #....................................
96
+
97
+ def slice_segments(self,hidden_states, ids_str, segment_size=4):
98
+
99
+ batch_size, channels, _ = hidden_states.shape
100
+ # 1d tensor containing the indices to keep
101
+ indices = torch.arange(segment_size).to(ids_str.device)
102
+ # extend the indices to match the shape of hidden_states
103
+ indices = indices.view(1, 1, -1).expand(batch_size, channels, -1)
104
+ # offset indices with ids_str
105
+ indices = indices + ids_str.view(-1, 1, 1)
106
+ # gather indices
107
+ output = torch.gather(hidden_states, dim=2, index=indices)
108
+
109
+ return output
110
+
111
+
112
+ #....................................
113
+
114
+
115
+ def rand_slice_segments(self,hidden_states, sample_lengths=None, segment_size=4):
116
+
117
+ batch_size, _, seq_len = hidden_states.size()
118
+ if sample_lengths is None:
119
+ sample_lengths = seq_len
120
+ ids_str_max = sample_lengths - segment_size + 1
121
+ ids_str = (torch.rand([batch_size]).to(device=hidden_states.device) * ids_str_max).to(dtype=torch.long)
122
+ ret = self.slice_segments(hidden_states, ids_str, segment_size)
123
+
124
+ return ret, ids_str
125
+
126
+ #....................................
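For clarity, here is a small self-contained sketch of the segment slicing used above, with the body of `slice_segments` inlined as a free function; the tensor shapes and start indices are invented for the example.

```python
import torch

def slice_segments(hidden_states, ids_str, segment_size=4):
    # Same logic as the method above: gather `segment_size` frames per batch item,
    # starting at the per-item offsets in `ids_str`.
    batch_size, channels, _ = hidden_states.shape
    indices = torch.arange(segment_size).to(ids_str.device)
    indices = indices.view(1, 1, -1).expand(batch_size, channels, -1)
    indices = indices + ids_str.view(-1, 1, 1)
    return torch.gather(hidden_states, dim=2, index=indices)

latents = torch.randn(2, 192, 100)   # (batch, channels, frames)
starts = torch.tensor([10, 37])      # one start frame per batch item
chunk = slice_segments(latents, starts, segment_size=32)
print(chunk.shape)                   # torch.Size([2, 192, 32])
```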
127
+
128
+ def resize_speaker_embeddings(
129
+ self,
130
+ new_num_speakers: int,
131
+ speaker_embedding_size: Optional[int] = None,
132
+ pad_to_multiple_of: Optional[int] = 2,
133
+ ):
134
+ if pad_to_multiple_of is not None:
135
+ new_num_speakers = ((new_num_speakers + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
136
+
137
+ # first, take care of embed_speaker
138
+ if self.config.num_speakers <= 1:
139
+ if speaker_embedding_size is None:
140
+ raise ValueError(
141
+ "The current model had no previous speaker embedding, but `speaker_embedding_size` is not specified. Pass `speaker_embedding_size` to this method."
142
+ )
143
+ # create new embedding layer
144
+ new_embeddings = nn.Embedding(
145
+ new_num_speakers,
146
+ speaker_embedding_size,
147
+ device=self.device,
148
+ )
149
+ # initialize all new embeddings
150
+ self._init_weights(new_embeddings)
151
+ else:
152
+ new_embeddings = self._get_resized_embeddings(self.embed_speaker, new_num_speakers)
153
+
154
+ self.embed_speaker = new_embeddings
155
+
156
+ # then take care of sub-models
157
+ self.flow.resize_speaker_embeddings(speaker_embedding_size)
158
+ for flow in self.flow.flows:
159
+ self._init_weights(flow.wavenet.cond_layer)
160
+
161
+ self.decoder.resize_speaker_embedding(speaker_embedding_size)
162
+ self._init_weights(self.decoder.cond)
163
+
164
+ self.duration_predictor.resize_speaker_embeddings(speaker_embedding_size)
165
+ self._init_weights(self.duration_predictor.cond)
166
+
167
+ self.posterior_encoder.resize_speaker_embeddings(speaker_embedding_size)
168
+ self._init_weights(self.posterior_encoder.wavenet.cond_layer)
169
+
170
+ self.config.num_speakers = new_num_speakers
171
+ self.config.speaker_embedding_size = speaker_embedding_size
172
+
173
+ #....................................
174
+
175
+ def get_input_embeddings(self):
176
+ return self.text_encoder.get_input_embeddings()
177
+
178
+ #....................................
179
+
180
+ def set_input_embeddings(self, value):
181
+ self.text_encoder.set_input_embeddings(value)
182
+
183
+ #....................................
184
+
185
+ def apply_weight_norm(self):
186
+ self.decoder.apply_weight_norm()
187
+ self.flow.apply_weight_norm()
188
+ self.posterior_encoder.apply_weight_norm()
189
+
190
+ #....................................
191
+
192
+ def remove_weight_norm(self):
193
+ self.decoder.remove_weight_norm()
194
+ self.flow.remove_weight_norm()
195
+ self.posterior_encoder.remove_weight_norm()
196
+
197
+ #....................................
198
+
199
+ def discriminate(self, hidden_states):
200
+ return self.discriminator(hidden_states)
201
+
202
+ #....................................
203
+
204
+ def get_encoder(self):
205
+ return self.text_encoder
206
+
207
+ #....................................
208
+
209
+ def _inference_forward(
210
+ self,
211
+ input_ids: Optional[torch.Tensor] = None,
212
+ attention_mask: Optional[torch.Tensor] = None,
213
+ speaker_embeddings: Optional[torch.Tensor] = None,
214
+ output_attentions: Optional[bool] = None,
215
+ output_hidden_states: Optional[bool] = None,
216
+ return_dict: Optional[bool] = None,
217
+ padding_mask: Optional[torch.Tensor] = None,
218
+ ):
219
+ text_encoder_output = self.text_encoder(
220
+ input_ids=input_ids,
221
+ padding_mask=padding_mask,
222
+ attention_mask=attention_mask,
223
+ output_attentions=output_attentions,
224
+ output_hidden_states=output_hidden_states,
225
+ return_dict=return_dict,
226
+ )
227
+ hidden_states = text_encoder_output[0] if not return_dict else text_encoder_output.last_hidden_state
228
+ hidden_states = hidden_states.transpose(1, 2)
229
+ input_padding_mask = padding_mask.transpose(1, 2)
230
+
231
+ prior_means = text_encoder_output[1] if not return_dict else text_encoder_output.prior_means
232
+ prior_log_variances = text_encoder_output[2] if not return_dict else text_encoder_output.prior_log_variances
233
+
234
+ if self.config.use_stochastic_duration_prediction:
235
+ log_duration = self.duration_predictor(
236
+ hidden_states,
237
+ input_padding_mask,
238
+ speaker_embeddings,
239
+ reverse=True,
240
+ noise_scale=self.noise_scale_duration,
241
+ )
242
+ else:
243
+ log_duration = self.duration_predictor(hidden_states, input_padding_mask, speaker_embeddings)
244
+
245
+ length_scale = 1.0 / self.speaking_rate
246
+ duration = torch.ceil(torch.exp(log_duration) * input_padding_mask * length_scale)
247
+ predicted_lengths = torch.clamp_min(torch.sum(duration, [1, 2]), 1).long()
248
+
249
+
250
+ # Create a padding mask for the output lengths of shape (batch, 1, max_output_length)
251
+ indices = torch.arange(predicted_lengths.max(), dtype=predicted_lengths.dtype, device=predicted_lengths.device)
252
+ output_padding_mask = indices.unsqueeze(0) < predicted_lengths.unsqueeze(1)
253
+ output_padding_mask = output_padding_mask.unsqueeze(1).to(input_padding_mask.dtype)
254
+
255
+ # Reconstruct an attention tensor of shape (batch, 1, out_length, in_length)
256
+ attn_mask = torch.unsqueeze(input_padding_mask, 2) * torch.unsqueeze(output_padding_mask, -1)
257
+ batch_size, _, output_length, input_length = attn_mask.shape
258
+ cum_duration = torch.cumsum(duration, -1).view(batch_size * input_length, 1)
259
+ indices = torch.arange(output_length, dtype=duration.dtype, device=duration.device)
260
+ valid_indices = indices.unsqueeze(0) < cum_duration
261
+ valid_indices = valid_indices.to(attn_mask.dtype).view(batch_size, input_length, output_length)
262
+ padded_indices = valid_indices - nn.functional.pad(valid_indices, [0, 0, 1, 0, 0, 0])[:, :-1]
263
+ attn = padded_indices.unsqueeze(1).transpose(2, 3) * attn_mask
264
+
265
+ # Expand prior distribution
266
+ prior_means = torch.matmul(attn.squeeze(1), prior_means).transpose(1, 2)
267
+ prior_log_variances = torch.matmul(attn.squeeze(1), prior_log_variances).transpose(1, 2)
268
+
269
+ prior_latents = prior_means + torch.randn_like(prior_means) * torch.exp(prior_log_variances) * self.noise_scale
270
+ latents = self.flow(prior_latents, output_padding_mask, speaker_embeddings, reverse=True)
271
+
272
+ spectrogram = latents * output_padding_mask
273
+ waveform = self.decoder(spectrogram, speaker_embeddings)
274
+ waveform = waveform.squeeze(1)
275
+ sequence_lengths = predicted_lengths * np.prod(self.config.upsample_rates)
276
+
277
+ if not return_dict:
278
+ outputs = (waveform, sequence_lengths, spectrogram) + text_encoder_output[3:]
279
+ return outputs
280
+
281
+ return VitsModelOutput(
282
+ waveform=waveform,
283
+ sequence_lengths=sequence_lengths,
284
+ spectrogram=spectrogram,
285
+ hidden_states=text_encoder_output.hidden_states,
286
+ attentions=text_encoder_output.attentions,
287
+ )
288
+
289
+ #....................................
290
+
291
+ def forward(
292
+ self,
293
+ input_ids: Optional[torch.Tensor] = None,
294
+ attention_mask: Optional[torch.Tensor] = None,
295
+ speaker_id: Optional[int] = None,
296
+ output_attentions: Optional[bool] = None,
297
+ output_hidden_states: Optional[bool] = None,
298
+ return_dict: Optional[bool] = None,
299
+ labels: Optional[torch.FloatTensor] = None,
300
+ labels_attention_mask: Optional[torch.Tensor] = None,
301
+ monotonic_alignment_function: Optional[Callable] = None,
302
+ ) -> Union[Tuple[Any], VitsModelOutput]:
303
+
304
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
305
+ output_hidden_states = (
306
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
307
+ )
308
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
309
+
310
+ monotonic_alignment_function = (
311
+ self.monotonic_align_max_path if monotonic_alignment_function is None else monotonic_alignment_function
312
+ )
313
+
314
+ if attention_mask is not None:
315
+ input_padding_mask = attention_mask.unsqueeze(-1).float()
316
+ else:
317
+ input_padding_mask = torch.ones_like(input_ids).unsqueeze(-1).float()
318
+
319
+ if self.config.num_speakers > 1 and speaker_id is not None:
320
+ if isinstance(speaker_id, int):
321
+ speaker_id = torch.full(size=(1,), fill_value=speaker_id, device=self.device)
322
+ elif isinstance(speaker_id, (list, tuple, np.ndarray)):
323
+ speaker_id = torch.tensor(speaker_id, device=self.device)
324
+
325
+ if not ((0 <= speaker_id).all() and (speaker_id < self.config.num_speakers).all()).item():
326
+ raise ValueError(f"Set `speaker_id` in the range 0-{self.config.num_speakers - 1}.")
327
+ if not (len(speaker_id) == 1 or len(speaker_id) == len(input_ids)):
328
+ raise ValueError(
329
+ f"You passed {len(speaker_id)} `speaker_id` but you should either pass one speaker id or `batch_size` `speaker_id`."
330
+ )
331
+
332
+ speaker_embeddings = self.embed_speaker(speaker_id).unsqueeze(-1)
333
+ else:
334
+ speaker_embeddings = None
335
+
336
+ # if inference, return inference forward of VitsModel
337
+ if labels is None:
338
+ return self._inference_forward(
339
+ input_ids,
340
+ attention_mask,
341
+ speaker_embeddings,
342
+ output_attentions,
343
+ output_hidden_states,
344
+ return_dict,
345
+ input_padding_mask,
346
+ )
347
+
348
+ if labels_attention_mask is not None:
349
+ labels_padding_mask = labels_attention_mask.unsqueeze(1).float()
350
+ else:
351
+ labels_attention_mask = torch.ones((labels.shape[0], labels.shape[2])).float().to(self.device)
352
+ labels_padding_mask = labels_attention_mask.unsqueeze(1)
353
+
354
+ text_encoder_output = self.text_encoder(
355
+ input_ids=input_ids,
356
+ padding_mask=input_padding_mask,
357
+ attention_mask=attention_mask,
358
+ output_attentions=output_attentions,
359
+ output_hidden_states=output_hidden_states,
360
+ return_dict=return_dict,
361
+ )
362
+ hidden_states = text_encoder_output[0] if not return_dict else text_encoder_output.last_hidden_state
363
+ hidden_states = hidden_states.transpose(1, 2)
364
+ input_padding_mask = input_padding_mask.transpose(1, 2)
365
+ prior_means = text_encoder_output[1] if not return_dict else text_encoder_output.prior_means
366
+ prior_log_variances = text_encoder_output[2] if not return_dict else text_encoder_output.prior_log_variances
367
+
368
+ latents, posterior_means, posterior_log_variances = self.posterior_encoder(
369
+ labels, labels_padding_mask, speaker_embeddings
370
+ )
371
+ prior_latents = self.flow(latents, labels_padding_mask, speaker_embeddings, reverse=False)
372
+
373
+ prior_means, prior_log_variances = prior_means.transpose(1, 2), prior_log_variances.transpose(1, 2)
374
+ with torch.no_grad():
375
+ # negative cross-entropy
376
+
377
+ # [batch_size, d, latent_length]
378
+ prior_variances = torch.exp(-2 * prior_log_variances)
379
+ # [batch_size, 1, latent_length]
380
+ neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - prior_log_variances, [1], keepdim=True)
381
+ # [batch_size, text_length, d] x [batch_size, d, latent_length] = [batch_size, text_length, latent_length]
382
+ neg_cent2 = torch.matmul(-0.5 * (prior_latents**2).transpose(1, 2), prior_variances)
383
+ # [batch_size, text_length, d] x [batch_size, d, latent_length] = [batch_size, text_length, latent_length]
384
+ neg_cent3 = torch.matmul(prior_latents.transpose(1, 2), (prior_means * prior_variances))
385
+ # [batch_size, 1, latent_length]
386
+ neg_cent4 = torch.sum(-0.5 * (prior_means**2) * prior_variances, [1], keepdim=True)
387
+
388
+ # [batch_size, text_length, latent_length]
389
+ neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
390
+
391
+ attn_mask = torch.unsqueeze(input_padding_mask, 2) * torch.unsqueeze(labels_padding_mask, -1)
392
+
393
+ attn = monotonic_alignment_function(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()
394
+
395
+ durations = attn.sum(2)
396
+
397
+ if self.config.use_stochastic_duration_prediction:
398
+ log_duration = self.duration_predictor(
399
+ hidden_states, input_padding_mask, speaker_embeddings, durations=durations, reverse=False
400
+ )
401
+ log_duration = log_duration / torch.sum(input_padding_mask)
402
+ else:
403
+ log_duration_padded = torch.log(durations + 1e-6) * input_padding_mask
404
+ log_duration = self.duration_predictor(hidden_states, input_padding_mask, speaker_embeddings)
405
+ log_duration = torch.sum((log_duration - log_duration_padded) ** 2, [1, 2]) / torch.sum(input_padding_mask)
406
+
407
+ # expand priors
408
+ prior_means = torch.matmul(attn.squeeze(1), prior_means.transpose(1, 2)).transpose(1, 2)
409
+ prior_log_variances = torch.matmul(attn.squeeze(1), prior_log_variances.transpose(1, 2)).transpose(1, 2)
410
+
411
+ label_lengths = labels_attention_mask.sum(dim=1)
412
+ latents_slice, ids_slice = self.rand_slice_segments(latents, label_lengths, segment_size=self.segment_size)
413
+
414
+ waveform = self.decoder(latents_slice, speaker_embeddings)
415
+
416
+ if not return_dict:
417
+ outputs = (
418
+ waveform,
419
+ log_duration,
420
+ attn,
421
+ ids_slice,
422
+ input_padding_mask,
423
+ labels_padding_mask,
424
+ latents,
425
+ prior_latents,
426
+ prior_means,
427
+ prior_log_variances,
428
+ posterior_means,
429
+ posterior_log_variances,
430
+ )
431
+ return outputs
432
+
433
+ return VitsTrainingOutput(
434
+ waveform=waveform,
435
+ log_duration=log_duration,
436
+ attn=attn,
437
+ ids_slice=ids_slice,
438
+ input_padding_mask=input_padding_mask,
439
+ labels_padding_mask=labels_padding_mask,
440
+ latents=latents,
441
+ prior_latents=prior_latents,
442
+ prior_means=prior_means,
443
+ prior_log_variances=prior_log_variances,
444
+ posterior_means=posterior_means,
445
+ posterior_log_variances=posterior_log_variances,
446
+ )
447
+
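A hedged usage sketch for the `VitsModel` defined in this file follows. It assumes the `VitsModelSplit` folder is importable as a package, that a compatible checkpoint and tokenizer exist at the placeholder path, and that `return_dict` resolves to `True` so the call returns a `VitsModelOutput`.

```python
import torch
from transformers import AutoTokenizer
from VitsModelSplit.vits_model import VitsModel  # assumed import path

checkpoint = "path/or/repo-id"                   # placeholder checkpoint
model = VitsModel.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model.eval()

inputs = tokenizer("مرحبا بكم", return_tensors="pt")
with torch.no_grad():
    # labels is None, so forward dispatches to _inference_forward
    output = model(input_ids=inputs["input_ids"],
                   attention_mask=inputs["attention_mask"])
waveform = output.waveform                       # (batch, samples)
```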
VitsModelSplit/vits_model2.py ADDED
The diff for this file is too large to render. See raw diff
 
VitsModelSplit/vits_model3.py ADDED
@@ -0,0 +1,670 @@
1
+
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+ import math
6
+ from typing import Any, Callable, Optional, Tuple, Union
7
+ from torch.cuda.amp import autocast, GradScaler
8
+
9
+ from .vits_config import VitsConfig,VitsPreTrainedModel
10
+ from .flow import VitsResidualCouplingBlock
11
+ from .duration_predictor import VitsDurationPredictor, VitsStochasticDurationPredictor
12
+ from .encoder import VitsTextEncoder
13
+ from .decoder import VitsHifiGan
14
+ from .posterior_encoder import VitsPosteriorEncoder
15
+ from .discriminator import VitsDiscriminator
16
+ from .vits_output import VitsModelOutput, VitsTrainingOutput
17
+ from .dataset_features_collector import FeaturesCollectionDataset
18
+ from .feature_extraction import VitsFeatureExtractor
19
+
20
+ import os
21
+ import sys
22
+ from typing import Optional
23
+ import tempfile
24
+ from torch.cuda.amp import autocast, GradScaler
25
+
26
+ from IPython.display import clear_output
27
+ from transformers import set_seed
28
+ import wandb
29
+ import logging
30
+ import copy
31
+ Lst=['input_ids',
32
+ 'attention_mask',
33
+ 'waveform',
34
+ 'labels',
35
+ 'labels_attention_mask',
36
+ 'mel_scaled_input_features']
37
+
38
+ def covert_cuda_batch(d):
39
+ #return d
40
+ for key in Lst:
41
+ d[key]=d[key].cuda(non_blocking=True)
42
+ # for key in d['text_encoder_output']:
43
+ # d['text_encoder_output'][key]=d['text_encoder_output'][key].cuda(non_blocking=True)
44
+ for key in d['posterior_encode_output']:
45
+ d['posterior_encode_output'][key]=d['posterior_encode_output'][key].cuda(non_blocking=True)
46
+
47
+ return d
48
+ def generator_loss(disc_outputs):
49
+ total_loss = 0
50
+ gen_losses = []
51
+ for disc_output in disc_outputs:
52
+ disc_output = disc_output
53
+ loss = torch.mean((1 - disc_output) ** 2)
54
+ gen_losses.append(loss)
55
+ total_loss += loss
56
+
57
+ return total_loss, gen_losses
58
+
59
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
60
+ loss = 0
61
+ real_losses = 0
62
+ generated_losses = 0
63
+ for disc_real, disc_generated in zip(disc_real_outputs, disc_generated_outputs):
64
+ real_loss = torch.mean((1 - disc_real) ** 2)
65
+ generated_loss = torch.mean(disc_generated**2)
66
+ loss += real_loss + generated_loss
67
+ real_losses += real_loss
68
+ generated_losses += generated_loss
69
+
70
+ return loss, real_losses, generated_losses
71
+
72
+ def feature_loss(feature_maps_real, feature_maps_generated):
73
+ loss = 0
74
+ for feature_map_real, feature_map_generated in zip(feature_maps_real, feature_maps_generated):
75
+ for real, generated in zip(feature_map_real, feature_map_generated):
76
+ real = real.detach()
77
+ loss += torch.mean(torch.abs(real - generated))
78
+
79
+ return loss * 2
80
+ def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
81
+ """
82
+ z_p, logs_q: [b, h, t_t]
83
+ m_p, logs_p: [b, h, t_t]
84
+ """
85
+ z_p = z_p.float()
86
+ logs_q = logs_q.float()
87
+ m_p = m_p.float()
88
+ logs_p = logs_p.float()
89
+ z_mask = z_mask.float()
90
+
91
+ kl = logs_p - logs_q - 0.5
92
+ kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. * logs_p)
93
+ kl = torch.sum(kl * z_mask)
94
+ l = kl / torch.sum(z_mask)
95
+ return l
96
+ #.............................................
97
+ # def kl_loss(prior_latents, posterior_log_variance, prior_means, prior_log_variance, labels_mask):
98
+
99
+
100
+ # kl = prior_log_variance - posterior_log_variance - 0.5
101
+ # kl += 0.5 * ((prior_latents - prior_means) ** 2) * torch.exp(-2.0 * prior_log_variance)
102
+ # kl = torch.sum(kl * labels_mask)
103
+ # loss = kl / torch.sum(labels_mask)
104
+ # return loss
105
+
106
+ def get_state_grad_loss(k1=True,
107
+ mel=True,
108
+ duration=True,
109
+ generator=True,
110
+ discriminator=True):
111
+ return {'k1':k1,'mel':mel,'duration':duration,'generator':generator,'discriminator':discriminator}
112
+
113
+
114
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
115
+ if isinstance(parameters, torch.Tensor):
116
+ parameters = [parameters]
117
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
118
+ norm_type = float(norm_type)
119
+ if clip_value is not None:
120
+ clip_value = float(clip_value)
121
+
122
+ total_norm = 0
123
+ for p in parameters:
124
+ param_norm = p.grad.data.norm(norm_type)
125
+ total_norm += param_norm.item() ** norm_type
126
+ if clip_value is not None:
127
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
128
+ total_norm = total_norm ** (1. / norm_type)
129
+ return total_norm
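The sketch below shows one plausible way the helper losses above could be combined in a VITS-style GAN training step. It assumes it runs in the same module (so `generator_loss`, `discriminator_loss`, `feature_loss` and `kl_loss` are in scope), and the mel weight of 45.0 follows the original VITS recipe rather than anything defined in this file.

```python
import torch

def generator_step(disc_fake_outputs, fmaps_real, fmaps_fake, mel_real, mel_fake,
                   log_duration, z_p, logs_q, m_p, logs_p, z_mask):
    # adversarial + feature-matching + mel reconstruction + KL + duration terms
    loss_gen, _ = generator_loss(disc_fake_outputs)
    loss_fmap = feature_loss(fmaps_real, fmaps_fake)
    loss_mel = torch.nn.functional.l1_loss(mel_real, mel_fake) * 45.0  # VITS paper weight
    loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask)
    loss_dur = torch.sum(log_duration)
    return loss_gen + loss_fmap + loss_mel + loss_kl + loss_dur

def discriminator_step(disc_real_outputs, disc_fake_outputs):
    loss_disc, _, _ = discriminator_loss(disc_real_outputs, disc_fake_outputs)
    return loss_disc
```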
130
+
131
+
132
+ class VitsModel(VitsPreTrainedModel):
133
+
134
+ def __init__(self, config: VitsConfig):
135
+ super().__init__(config)
136
+
137
+ self.config = config
138
+ self.text_encoder = VitsTextEncoder(config)
139
+ self.flow = VitsResidualCouplingBlock(config)
140
+ self.decoder = VitsHifiGan(config)
141
+
142
+
143
+
144
+ if config.use_stochastic_duration_prediction:
145
+ self.duration_predictor = VitsStochasticDurationPredictor(config)
146
+ else:
147
+ self.duration_predictor = VitsDurationPredictor(config)
148
+
149
+ if config.num_speakers > 1:
150
+ self.embed_speaker = nn.Embedding(config.num_speakers, config.speaker_embedding_size)
151
+
152
+ # This is used only for training.
153
+ self.posterior_encoder = VitsPosteriorEncoder(config)
154
+ self.discriminator = VitsDiscriminator(config)
155
+
156
+ # These parameters control the synthesised speech properties
157
+ self.speaking_rate = config.speaking_rate
158
+ self.noise_scale = config.noise_scale
159
+ self.noise_scale_duration = config.noise_scale_duration
160
+ self.segment_size = self.config.segment_size // self.config.hop_length
161
+
162
+ # Initialize weights and apply final processing
163
+ self.post_init()
164
+ self.monotonic_alignment_function=self.monotonic_align_max_path
165
+
166
+
167
+
168
+ #....................................
169
+ def setMfA(self,fn):
170
+ self.monotonic_alignment_function=fn
171
+
172
+
173
+
174
+ def monotonic_align_max_path(self,log_likelihoods, mask):
175
+ # used for training - awfully slow
176
+ # an alternative is proposed in examples/pytorch/text-to-speech/run_vits_finetuning.py
177
+ path = torch.zeros_like(log_likelihoods)
178
+
179
+ text_length_maxs = mask.sum(1)[:, 0]
180
+ latent_length_maxs = mask.sum(2)[:, 0]
181
+
182
+ indexes = latent_length_maxs - 1
183
+
184
+ max_neg_val = -1e9
185
+
186
+ for batch_id in range(len(path)):
187
+ index = int(indexes[batch_id].item())
188
+ text_length_max = int(text_length_maxs[batch_id].item())
189
+ latent_length_max = int(latent_length_maxs[batch_id].item())
190
+
191
+ for y in range(text_length_max):
192
+ for x in range(max(0, latent_length_max + y - text_length_max), min(latent_length_max, y + 1)):
193
+ if x == y:
194
+ v_cur = max_neg_val
195
+ else:
196
+ v_cur = log_likelihoods[batch_id, y - 1, x]
197
+ if x == 0:
198
+ if y == 0:
199
+ v_prev = 0.0
200
+ else:
201
+ v_prev = max_neg_val
202
+ else:
203
+ v_prev = log_likelihoods[batch_id, y - 1, x - 1]
204
+ log_likelihoods[batch_id, y, x] += max(v_prev, v_cur)
205
+
206
+ for y in range(text_length_max - 1, -1, -1):
207
+ path[batch_id, y, index] = 1
208
+ if index != 0 and (
209
+ index == y or log_likelihoods[batch_id, y - 1, index] < log_likelihoods[batch_id, y - 1, index - 1]
210
+ ):
211
+ index = index - 1
212
+ return path
213
+
214
+ #....................................
215
+
216
+ def slice_segments(self,hidden_states, ids_str, segment_size=4):
217
+
218
+ batch_size, channels, _ = hidden_states.shape
219
+ # 1d tensor containing the indices to keep
220
+ indices = torch.arange(segment_size).to(ids_str.device)
221
+ # extend the indices to match the shape of hidden_states
222
+ indices = indices.view(1, 1, -1).expand(batch_size, channels, -1)
223
+ # offset indices with ids_str
224
+ indices = indices + ids_str.view(-1, 1, 1)
225
+ # gather indices
226
+ output = torch.gather(hidden_states, dim=2, index=indices)
227
+
228
+ return output
229
+
230
+
231
+ #....................................
232
+
233
+
234
+ def rand_slice_segments(self,hidden_states, sample_lengths=None, segment_size=4):
235
+
236
+ batch_size, _, seq_len = hidden_states.size()
237
+ if sample_lengths is None:
238
+ sample_lengths = seq_len
239
+ ids_str_max = sample_lengths - segment_size + 1
240
+ ids_str = (torch.rand([batch_size]).to(device=hidden_states.device) * ids_str_max).to(dtype=torch.long)
241
+ ret = self.slice_segments(hidden_states, ids_str, segment_size)
242
+
243
+ return ret, ids_str
244
+
245
+ #....................................
246
+
247
+ def resize_speaker_embeddings(
248
+ self,
249
+ new_num_speakers: int,
250
+ speaker_embedding_size: Optional[int] = None,
251
+ pad_to_multiple_of: Optional[int] = 2,
252
+ ):
253
+ if pad_to_multiple_of is not None:
254
+ new_num_speakers = ((new_num_speakers + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
255
+
256
+ # first, take care of embed_speaker
257
+ if self.config.num_speakers <= 1:
258
+ if speaker_embedding_size is None:
259
+ raise ValueError(
260
+ "The current model had no previous speaker embedding, but `speaker_embedding_size` is not specified. Pass `speaker_embedding_size` to this method."
261
+ )
262
+ # create new embedding layer
263
+ new_embeddings = nn.Embedding(
264
+ new_num_speakers,
265
+ speaker_embedding_size,
266
+ device=self.device,
267
+ )
268
+ # initialize all new embeddings
269
+ self._init_weights(new_embeddings)
270
+ else:
271
+ new_embeddings = self._get_resized_embeddings(self.embed_speaker, new_num_speakers)
272
+
273
+ self.embed_speaker = new_embeddings
274
+
275
+ # then take care of sub-models
276
+ self.flow.resize_speaker_embeddings(speaker_embedding_size)
277
+ for flow in self.flow.flows:
278
+ self._init_weights(flow.wavenet.cond_layer)
279
+
280
+ self.decoder.resize_speaker_embedding(speaker_embedding_size)
281
+ self._init_weights(self.decoder.cond)
282
+
283
+ self.duration_predictor.resize_speaker_embeddings(speaker_embedding_size)
284
+ self._init_weights(self.duration_predictor.cond)
285
+
286
+ self.posterior_encoder.resize_speaker_embeddings(speaker_embedding_size)
287
+ self._init_weights(self.posterior_encoder.wavenet.cond_layer)
288
+
289
+ self.config.num_speakers = new_num_speakers
290
+ self.config.speaker_embedding_size = speaker_embedding_size
291
+
292
+ #....................................
293
+
294
+ def get_input_embeddings(self):
295
+ return self.text_encoder.get_input_embeddings()
296
+
297
+ #....................................
298
+
299
+ def set_input_embeddings(self, value):
300
+ self.text_encoder.set_input_embeddings(value)
301
+
302
+ #....................................
303
+
304
+ def apply_weight_norm(self):
305
+ self.decoder.apply_weight_norm()
306
+ self.flow.apply_weight_norm()
307
+ self.posterior_encoder.apply_weight_norm()
308
+
309
+ #....................................
310
+
311
+ def remove_weight_norm(self):
312
+ self.decoder.remove_weight_norm()
313
+ self.flow.remove_weight_norm()
314
+ self.posterior_encoder.remove_weight_norm()
315
+
316
+ #....................................
317
+
318
+ def discriminate(self, hidden_states):
319
+ return self.discriminator(hidden_states)
320
+
321
+ #....................................
322
+
323
+ def get_encoder(self):
324
+ return self.text_encoder
325
+
326
+ #....................................
327
+
328
+ def _inference_forward(
329
+ self,
330
+ input_ids: Optional[torch.Tensor] = None,
331
+ attention_mask: Optional[torch.Tensor] = None,
332
+ speaker_embeddings: Optional[torch.Tensor] = None,
333
+ output_attentions: Optional[bool] = None,
334
+ output_hidden_states: Optional[bool] = None,
335
+ return_dict: Optional[bool] = None,
336
+ padding_mask: Optional[torch.Tensor] = None,
337
+ ):
338
+ text_encoder_output = self.text_encoder(
339
+ input_ids=input_ids,
340
+ padding_mask=padding_mask,
341
+ attention_mask=attention_mask,
342
+ output_attentions=output_attentions,
343
+ output_hidden_states=output_hidden_states,
344
+ return_dict=return_dict,
345
+ )
346
+ hidden_states = text_encoder_output[0] if not return_dict else text_encoder_output.last_hidden_state
347
+ hidden_states = hidden_states.transpose(1, 2)
348
+ input_padding_mask = padding_mask.transpose(1, 2)
349
+
350
+ prior_means = text_encoder_output[1] if not return_dict else text_encoder_output.prior_means
351
+ prior_log_variances = text_encoder_output[2] if not return_dict else text_encoder_output.prior_log_variances
352
+
353
+ if self.config.use_stochastic_duration_prediction:
354
+ log_duration = self.duration_predictor(
355
+ hidden_states,
356
+ input_padding_mask,
357
+ speaker_embeddings,
358
+ reverse=True,
359
+ noise_scale=self.noise_scale_duration,
360
+ )
361
+ else:
362
+ log_duration = self.duration_predictor(hidden_states, input_padding_mask, speaker_embeddings)
363
+
364
+ length_scale = 1.0 / self.speaking_rate
365
+ duration = torch.ceil(torch.exp(log_duration) * input_padding_mask * length_scale)
366
+ predicted_lengths = torch.clamp_min(torch.sum(duration, [1, 2]), 1).long()
367
+
368
+
369
+ # Create a padding mask for the output lengths of shape (batch, 1, max_output_length)
370
+ indices = torch.arange(predicted_lengths.max(), dtype=predicted_lengths.dtype, device=predicted_lengths.device)
371
+ output_padding_mask = indices.unsqueeze(0) < predicted_lengths.unsqueeze(1)
372
+ output_padding_mask = output_padding_mask.unsqueeze(1).to(input_padding_mask.dtype)
373
+
374
+ # Reconstruct an attention tensor of shape (batch, 1, out_length, in_length)
375
+ attn_mask = torch.unsqueeze(input_padding_mask, 2) * torch.unsqueeze(output_padding_mask, -1)
376
+ batch_size, _, output_length, input_length = attn_mask.shape
377
+ cum_duration = torch.cumsum(duration, -1).view(batch_size * input_length, 1)
378
+ indices = torch.arange(output_length, dtype=duration.dtype, device=duration.device)
379
+ valid_indices = indices.unsqueeze(0) < cum_duration
380
+ valid_indices = valid_indices.to(attn_mask.dtype).view(batch_size, input_length, output_length)
381
+ padded_indices = valid_indices - nn.functional.pad(valid_indices, [0, 0, 1, 0, 0, 0])[:, :-1]
382
+ attn = padded_indices.unsqueeze(1).transpose(2, 3) * attn_mask
383
+
384
+ # Expand prior distribution
385
+ prior_means = torch.matmul(attn.squeeze(1), prior_means).transpose(1, 2)
386
+ prior_log_variances = torch.matmul(attn.squeeze(1), prior_log_variances).transpose(1, 2)
387
+
388
+ prior_latents = prior_means + torch.randn_like(prior_means) * torch.exp(prior_log_variances) * self.noise_scale
389
+ latents = self.flow(prior_latents, output_padding_mask, speaker_embeddings, reverse=True)
390
+
391
+ spectrogram = latents * output_padding_mask
392
+ waveform = self.decoder(spectrogram, speaker_embeddings)
393
+ waveform = waveform.squeeze(1)
394
+ sequence_lengths = predicted_lengths * np.prod(self.config.upsample_rates)
395
+
396
+ if not return_dict:
397
+ outputs = (waveform, sequence_lengths, spectrogram) + text_encoder_output[3:]
398
+ return outputs
399
+
400
+ return VitsModelOutput(
401
+ waveform=waveform,
402
+ sequence_lengths=sequence_lengths,
403
+ spectrogram=spectrogram,
404
+ hidden_states=text_encoder_output.hidden_states,
405
+ attentions=text_encoder_output.attentions,
406
+ )
407
+
408
+ #....................................
409
+
410
+ def forward_k(
411
+ self,
412
+ input_ids: Optional[torch.Tensor] = None,
413
+ attention_mask: Optional[torch.Tensor] = None,
414
+ speaker_id: Optional[int] = None,
415
+ output_attentions: Optional[bool] = None,
416
+ output_hidden_states: Optional[bool] = None,
417
+ return_dict: Optional[bool] = None,
418
+ labels: Optional[torch.FloatTensor] = None,
419
+ labels_attention_mask: Optional[torch.Tensor] = None,
420
+ monotonic_alignment_function: Optional[Callable] = None,
421
+ ) -> Union[Tuple[Any], VitsModelOutput]:
422
+
423
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
424
+ output_hidden_states = (
425
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
426
+ )
427
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
428
+
429
+ monotonic_alignment_function = (
430
+ self.monotonic_align_max_path if monotonic_alignment_function is None else monotonic_alignment_function
431
+ )
432
+
433
+ if attention_mask is not None:
434
+ input_padding_mask = attention_mask.unsqueeze(-1).float()
435
+ else:
436
+ input_padding_mask = torch.ones_like(input_ids).unsqueeze(-1).float()
437
+
438
+ if self.config.num_speakers > 1 and speaker_id is not None:
439
+ if isinstance(speaker_id, int):
440
+ speaker_id = torch.full(size=(1,), fill_value=speaker_id, device=self.device)
441
+ elif isinstance(speaker_id, (list, tuple, np.ndarray)):
442
+ speaker_id = torch.tensor(speaker_id, device=self.device)
443
+
444
+ if not ((0 <= speaker_id).all() and (speaker_id < self.config.num_speakers).all()).item():
445
+ raise ValueError(f"Set `speaker_id` in the range 0-{self.config.num_speakers - 1}.")
446
+ if not (len(speaker_id) == 1 or len(speaker_id) == len(input_ids)):
447
+ raise ValueError(
448
+ f"You passed {len(speaker_id)} `speaker_id` but you should either pass one speaker id or `batch_size` `speaker_id`."
449
+ )
450
+
451
+ speaker_embeddings = self.embed_speaker(speaker_id).unsqueeze(-1)
452
+ else:
453
+ speaker_embeddings = None
454
+
455
+ # if inference, return inference forward of VitsModel
456
+ if labels is None:
457
+ return self._inference_forward(
458
+ input_ids,
459
+ attention_mask,
460
+ speaker_embeddings,
461
+ output_attentions,
462
+ output_hidden_states,
463
+ return_dict,
464
+ input_padding_mask,
465
+ )
466
+
467
+ if labels_attention_mask is not None:
468
+ labels_padding_mask = labels_attention_mask.unsqueeze(1).float()
469
+ else:
470
+ labels_attention_mask = torch.ones((labels.shape[0], labels.shape[2])).float().to(self.device)
471
+ labels_padding_mask = labels_attention_mask.unsqueeze(1)
472
+
473
+ text_encoder_output = self.text_encoder(
474
+ input_ids=input_ids,
475
+ padding_mask=input_padding_mask,
476
+ attention_mask=attention_mask,
477
+ output_attentions=output_attentions,
478
+ output_hidden_states=output_hidden_states,
479
+ return_dict=return_dict,
480
+ )
481
+ hidden_states = text_encoder_output[0] if not return_dict else text_encoder_output.last_hidden_state
482
+ hidden_states = hidden_states.transpose(1, 2)
483
+ input_padding_mask = input_padding_mask.transpose(1, 2)
484
+ prior_means = text_encoder_output[1] if not return_dict else text_encoder_output.prior_means
485
+ prior_log_variances = text_encoder_output[2] if not return_dict else text_encoder_output.prior_log_variances
486
+
487
+ latents, posterior_means, posterior_log_variances = self.posterior_encoder(
488
+ labels, labels_padding_mask, speaker_embeddings
489
+ )
490
+ prior_latents = self.flow(latents, labels_padding_mask, speaker_embeddings, reverse=False)
491
+
492
+ prior_means, prior_log_variances = prior_means.transpose(1, 2), prior_log_variances.transpose(1, 2)
493
+ with torch.no_grad():
494
+ # negative cross-entropy
495
+
496
+ # [batch_size, d, latent_length]
497
+ prior_variances = torch.exp(-2 * prior_log_variances)
498
+ # [batch_size, 1, latent_length]
499
+ neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - prior_log_variances, [1], keepdim=True)
500
+ # [batch_size, text_length, d] x [batch_size, d, latent_length] = [batch_size, text_length, latent_length]
501
+ neg_cent2 = torch.matmul(-0.5 * (prior_latents**2).transpose(1, 2), prior_variances)
502
+ # [batch_size, text_length, d] x [batch_size, d, latent_length] = [batch_size, text_length, latent_length]
503
+ neg_cent3 = torch.matmul(prior_latents.transpose(1, 2), (prior_means * prior_variances))
504
+ # [batch_size, 1, latent_length]
505
+ neg_cent4 = torch.sum(-0.5 * (prior_means**2) * prior_variances, [1], keepdim=True)
506
+
507
+ # [batch_size, text_length, latent_length]
508
+ neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
509
+
510
+ attn_mask = torch.unsqueeze(input_padding_mask, 2) * torch.unsqueeze(labels_padding_mask, -1)
511
+
512
+ attn = monotonic_alignment_function(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()
513
+
514
+ durations = attn.sum(2)
515
+
516
+ if self.config.use_stochastic_duration_prediction:
517
+ log_duration = self.duration_predictor(
518
+ hidden_states, input_padding_mask, speaker_embeddings, durations=durations, reverse=False
519
+ )
520
+ log_duration = log_duration / torch.sum(input_padding_mask)
521
+ else:
522
+ log_duration_padded = torch.log(durations + 1e-6) * input_padding_mask
523
+ log_duration = self.duration_predictor(hidden_states, input_padding_mask, speaker_embeddings)
524
+ log_duration = torch.sum((log_duration - log_duration_padded) ** 2, [1, 2]) / torch.sum(input_padding_mask)
525
+
526
+ # expand priors
527
+ prior_means = torch.matmul(attn.squeeze(1), prior_means.transpose(1, 2)).transpose(1, 2)
528
+ prior_log_variances = torch.matmul(attn.squeeze(1), prior_log_variances.transpose(1, 2)).transpose(1, 2)
529
+
530
+ label_lengths = labels_attention_mask.sum(dim=1)
531
+ latents_slice, ids_slice = self.rand_slice_segments(latents, label_lengths, segment_size=self.segment_size)
532
+
533
+ waveform = self.decoder(latents_slice, speaker_embeddings)
534
+
535
+ if not return_dict:
536
+ outputs = (
537
+ waveform,
538
+ log_duration,
539
+ attn,
540
+ ids_slice,
541
+ input_padding_mask,
542
+ labels_padding_mask,
543
+ latents,
544
+ prior_latents,
545
+ prior_means,
546
+ prior_log_variances,
547
+ posterior_means,
548
+ posterior_log_variances,
549
+ )
550
+ return outputs
551
+
552
+ return VitsTrainingOutput(
553
+ waveform=waveform,
554
+ log_duration=log_duration,
555
+ attn=attn,
556
+ ids_slice=ids_slice,
557
+ input_padding_mask=input_padding_mask,
558
+ labels_padding_mask=labels_padding_mask,
559
+ latents=latents,
560
+ prior_latents=prior_latents,
561
+ prior_means=prior_means,
562
+ prior_log_variances=prior_log_variances,
563
+ posterior_means=posterior_means,
564
+ posterior_log_variances=posterior_log_variances,
565
+ )
566
+
567
+
568
+
569
+ def forward(
570
+ self,
571
+ input_ids: Optional[torch.Tensor] = None,
572
+ attention_mask: Optional[torch.Tensor] = None,
573
+ speaker_id: Optional[int] = None,
574
+ output_attentions: Optional[bool] = None,
575
+ output_hidden_states: Optional[bool] = None,
576
+ return_dict: Optional[bool] = None,
577
+ labels: Optional[torch.FloatTensor] = None,
578
+ labels_attention_mask: Optional[torch.Tensor] = None,
579
+ text_encoder_output=None,
580
+ posterior_encode_output=None,
581
+ monotonic_alignment_function: Optional[Callable] = None,
582
+ speaker_embeddings=None
583
+ ) -> Union[Tuple[Any], VitsModelOutput]:
584
+
585
+ #output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
586
+ output_hidden_states = (
587
+ output_hidden_states# if output_hidden_states is not None else self.config.output_hidden_states
588
+ )
589
+ # return_dict = return_dict if return_dict is not None else self.config.use_return_dict
590
+
591
+
592
+ # if attention_mask is not None:
593
+ input_padding_mask = attention_mask.unsqueeze(-1).float()
594
+ #else:
595
+ # input_padding_mask = torch.ones_like(input_ids).unsqueeze(-1).float()
596
+
597
+ # speaker_embeddings=None
598
+ # if labels_attention_mask is not None:
599
+ labels_padding_mask = labels_attention_mask.unsqueeze(1).float()
600
+ # else:
601
+ # labels_attention_mask = torch.ones((labels.shape[0], labels.shape[2])).float().to(self.device)
602
+ # labels_padding_mask = labels_attention_mask.unsqueeze(1)
603
+ if text_encoder_output is None:
604
+ text_encoder_output = self.text_encoder(
605
+ input_ids=input_ids,
606
+ padding_mask=input_padding_mask,
607
+ attention_mask=attention_mask,
608
+ output_attentions=output_attentions,
609
+ output_hidden_states=output_hidden_states,
610
+ return_dict=return_dict,
611
+ )
612
+ #hidden_states = text_encoder_output[0] #if not return_dict else text_encoder_output.last_hidden_state
613
+ hidden_states = text_encoder_output[0].transpose(1, 2)
614
+ input_padding_mask = input_padding_mask.transpose(1, 2)
615
+ prior_means = text_encoder_output[1].transpose(1, 2) #if not return_dict else text_encoder_output.prior_means
616
+ prior_log_variances = text_encoder_output[2].transpose(1, 2) #if not return_dict else text_encoder_output.prior_log_variances
617
+
618
+ # if posterior_encode_output is None:
619
+ # latents, posterior_means, posterior_log_variances = self.posterior_encoder(
620
+ # labels, labels_padding_mask, speaker_embeddings
621
+ # )
622
+ # else:
623
+ latents=posterior_encode_output['posterior_latents']
624
+ posterior_means=posterior_encode_output['posterior_means']
625
+ posterior_log_variances=posterior_encode_output['posterior_log_variances']
626
+
627
+ prior_latents = self.flow(latents, labels_padding_mask, speaker_embeddings, reverse=False)
628
+
629
+ # prior_means, prior_log_variances = prior_means.transpose(1, 2), prior_log_variances.transpose(1, 2)
630
+ with torch.no_grad():
631
+ # negative cross-entropy
632
+
633
+ # [batch_size, d, latent_length]
634
+ prior_variances = torch.exp(-2 * prior_log_variances)
635
+ # [batch_size, 1, latent_length]
636
+ neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - prior_log_variances, [1], keepdim=True)
637
+ # [batch_size, text_length, d] x [batch_size, d, latent_length] = [batch_size, text_length, latent_length]
638
+ neg_cent2 = torch.matmul(-0.5 * (prior_latents**2).transpose(1, 2), prior_variances)
639
+ # [batch_size, text_length, d] x [batch_size, d, latent_length] = [batch_size, text_length, latent_length]
640
+ neg_cent3 = torch.matmul(prior_latents.transpose(1, 2), (prior_means * prior_variances))
641
+ # [batch_size, 1, latent_length]
642
+ neg_cent4 = torch.sum(-0.5 * (prior_means**2) * prior_variances, [1], keepdim=True)
643
+
644
+ # [batch_size, text_length, latent_length]
645
+ neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
646
+
647
+ attn_mask = torch.unsqueeze(input_padding_mask, 2) * torch.unsqueeze(labels_padding_mask, -1)
648
+
649
+ attn = monotonic_alignment_function(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()
650
+
651
+ durations = attn.sum(2)
652
+
653
+ #if self.config.use_stochastic_duration_prediction:
654
+ log_duration = self.duration_predictor(
655
+ hidden_states, input_padding_mask, speaker_embeddings, durations=durations, reverse=False
656
+ )
657
+ log_duration = log_duration / torch.sum(input_padding_mask)
658
+ # else:
659
+ # log_duration_padded = torch.log(durations + 1e-6) * input_padding_mask
660
+ # log_duration = self.duration_predictor(hidden_states, input_padding_mask, speaker_embeddings)
661
+ # log_duration = torch.sum((log_duration - log_duration_padded) ** 2, [1, 2]) / torch.sum(input_padding_mask)
662
+
663
+ # expand priors
664
+ prior_means = torch.matmul(attn.squeeze(1), prior_means.transpose(1, 2)).transpose(1, 2)
665
+ prior_log_variances = torch.matmul(attn.squeeze(1), prior_log_variances.transpose(1, 2)).transpose(1, 2)
666
+
667
+ label_lengths = labels_attention_mask.sum(dim=1)
668
+ latents_slice, ids_slice = self.rand_slice_segments(latents, label_lengths, segment_size=self.segment_size)
669
+ waveform = self.decoder(latents_slice, speaker_embeddings)
670
+ return waveform,ids_slice,log_duration,prior_latents,posterior_log_variances,prior_means,prior_log_variances,labels_padding_mask
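Since this `forward` returns a plain tuple, the following hedged sketch shows how a caller might unpack it and feed the pieces into `kl_loss`. Here `model` and `batch` are placeholders, the batch keys mirror the `Lst` list defined near the top of this file, and the import path is an assumption.

```python
from VitsModelSplit.vits_model3 import kl_loss  # assumed import path

# `model` is a VitsModel from this module and `batch` a collated dict (placeholders).
(waveform, ids_slice, log_duration, prior_latents,
 posterior_log_variances, prior_means, prior_log_variances,
 labels_padding_mask) = model(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    labels_attention_mask=batch["labels_attention_mask"],
    posterior_encode_output=batch["posterior_encode_output"],
    speaker_embeddings=None,
    monotonic_alignment_function=model.monotonic_alignment_function,  # must be callable
)
loss_kl = kl_loss(prior_latents, posterior_log_variances, prior_means,
                  prior_log_variances, labels_padding_mask)
```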
VitsModelSplit/vits_output.py ADDED
@@ -0,0 +1,84 @@
1
+ from typing import Any, Optional, Tuple, Union,List,Dict
2
+ import torch
3
+ from dataclasses import dataclass
4
+ from transformers.modeling_outputs import (
5
+ BaseModelOutput,
6
+ ModelOutput,
7
+ )
8
+ #.............................................
9
+
10
+
11
+
12
+ @dataclass
13
+ class PosteriorDecoderModelOutput(ModelOutput):
14
+ labels_padding_mask: torch.FloatTensor = None
15
+ posterior_latents: torch.FloatTensor = None
16
+ posterior_means: torch.FloatTensor = None
17
+ posterior_log_variances: torch.FloatTensor = None
18
+ latents_slice : torch.FloatTensor = None
19
+ ids_slice: torch.FloatTensor = None
20
+ waveform: torch.FloatTensor = None
21
+
22
+ #.............................................................................................
23
+
24
+
25
+ @dataclass
26
+ class VitsModelOutput(ModelOutput):
27
+ waveform: torch.FloatTensor = None
28
+ sequence_lengths: torch.FloatTensor = None
29
+ spectrogram: Optional[Tuple[torch.FloatTensor]] = None
30
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
31
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
32
+
33
+ #.............................................................................................
34
+
35
+ @dataclass
36
+ class VitsTrainingOutput(ModelOutput):
37
+ waveform: torch.FloatTensor = None
38
+ log_duration: torch.FloatTensor = None
39
+ attn: torch.FloatTensor = None
40
+ ids_slice: torch.FloatTensor = None
41
+ input_padding_mask: torch.FloatTensor = None
42
+ labels_padding_mask: torch.FloatTensor = None
43
+ latents: torch.FloatTensor = None
44
+ prior_latents: torch.FloatTensor = None
45
+ prior_means: torch.FloatTensor = None
46
+ prior_log_variances: torch.FloatTensor = None
47
+ posterior_means: torch.FloatTensor = None
48
+ posterior_log_variances: torch.FloatTensor = None
49
+
50
+
51
+ #.............................................................................................
52
+
53
+ @dataclass
54
+ class VitsTextEncoderOutput(ModelOutput):
55
+ """
56
+ Describes the outputs for the VITS text encoder model, with potential hidden states and attentions.
57
+
58
+ Args:
59
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
60
+ Sequence of hidden-states at the output of the last layer of the model.
61
+ prior_means (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
62
+ The predicted mean values of the prior distribution for the latent text variables.
63
+ prior_log_variances (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
64
+ The predicted log-variance values of the prior distribution for the latent text variables.
65
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
66
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
67
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
68
+
69
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
70
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
71
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
72
+ sequence_length)`.
73
+
74
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
75
+ heads.
76
+ """
77
+
78
+ last_hidden_state: torch.FloatTensor = None
79
+ prior_means: torch.FloatTensor = None
80
+ prior_log_variances: torch.FloatTensor = None
81
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
82
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
83
+
84
+ #.............................................................................................
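These dataclasses subclass `transformers.ModelOutput`, so fields can be read by attribute or by position; a small illustrative sketch (the tensors are dummies and the import path is an assumption):

```python
import torch
from VitsModelSplit.vits_output import VitsModelOutput  # assumed import path

out = VitsModelOutput(waveform=torch.zeros(1, 16000),
                      sequence_lengths=torch.tensor([16000]))
print(out.waveform.shape)   # attribute access
print(out[0].shape)         # positional access to the first non-None field
print(out.to_tuple())       # (waveform, sequence_lengths); None fields are skipped
```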