Shivam Mehta committed on
Commit
3c10b34
1 Parent(s): f5a235a

Adding code

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. Makefile +30 -0
  2. app.py +253 -0
  3. diff_ttsg/__init__.py +0 -0
  4. diff_ttsg/__pycache__/__init__.cpython-310.pyc +0 -0
  5. diff_ttsg/data/__init__.py +0 -0
  6. diff_ttsg/data/__pycache__/__init__.cpython-310.pyc +0 -0
  7. diff_ttsg/data/__pycache__/cormac_datamodule.cpython-310.pyc +0 -0
  8. diff_ttsg/data/components/__init__.py +0 -0
  9. diff_ttsg/data/cormac_datamodule.py +214 -0
  10. diff_ttsg/data/mnist_datamodule.py +130 -0
  11. diff_ttsg/eval.py +93 -0
  12. diff_ttsg/hifigan/LICENSE +21 -0
  13. diff_ttsg/hifigan/README.md +105 -0
  14. diff_ttsg/hifigan/__init__.py +0 -0
  15. diff_ttsg/hifigan/__pycache__/__init__.cpython-310.pyc +0 -0
  16. diff_ttsg/hifigan/__pycache__/config.cpython-310.pyc +0 -0
  17. diff_ttsg/hifigan/__pycache__/denoiser.cpython-310.pyc +0 -0
  18. diff_ttsg/hifigan/__pycache__/env.cpython-310.pyc +0 -0
  19. diff_ttsg/hifigan/__pycache__/models.cpython-310.pyc +0 -0
  20. diff_ttsg/hifigan/__pycache__/xutils.cpython-310.pyc +0 -0
  21. diff_ttsg/hifigan/config.py +38 -0
  22. diff_ttsg/hifigan/denoiser.py +64 -0
  23. diff_ttsg/hifigan/env.py +17 -0
  24. diff_ttsg/hifigan/meldataset.py +171 -0
  25. diff_ttsg/hifigan/models.py +286 -0
  26. diff_ttsg/hifigan/xutils.py +60 -0
  27. diff_ttsg/models/__init__.py +0 -0
  28. diff_ttsg/models/__pycache__/__init__.cpython-310.pyc +0 -0
  29. diff_ttsg/models/__pycache__/diff_ttsg.cpython-310.pyc +0 -0
  30. diff_ttsg/models/components/__init__.py +0 -0
  31. diff_ttsg/models/components/__pycache__/__init__.cpython-310.pyc +0 -0
  32. diff_ttsg/models/components/__pycache__/diffusion.cpython-310.pyc +0 -0
  33. diff_ttsg/models/components/__pycache__/text_encoder.cpython-310.pyc +0 -0
  34. diff_ttsg/models/components/__pycache__/transformer.cpython-310.pyc +0 -0
  35. diff_ttsg/models/components/diffusion.py +376 -0
  36. diff_ttsg/models/components/text_encoder.py +384 -0
  37. diff_ttsg/models/components/transformer.py +250 -0
  38. diff_ttsg/models/diff_ttsg.py +376 -0
  39. diff_ttsg/models/mnist_module.py +137 -0
  40. diff_ttsg/resources/cmu_dictionary +0 -0
  41. diff_ttsg/text/LICENSE +30 -0
  42. diff_ttsg/text/__init__.py +96 -0
  43. diff_ttsg/text/__pycache__/__init__.cpython-310.pyc +0 -0
  44. diff_ttsg/text/__pycache__/cleaners.cpython-310.pyc +0 -0
  45. diff_ttsg/text/__pycache__/cmudict.cpython-310.pyc +0 -0
  46. diff_ttsg/text/__pycache__/numbers.cpython-310.pyc +0 -0
  47. diff_ttsg/text/__pycache__/symbols.cpython-310.pyc +0 -0
  48. diff_ttsg/text/cleaners.py +73 -0
  49. diff_ttsg/text/cmudict.py +60 -0
  50. diff_ttsg/text/numbers.py +72 -0
Makefile ADDED
@@ -0,0 +1,30 @@
+
+ help: ## Show help
+ 	@grep -E '^[.a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
+
+ clean: ## Clean autogenerated files
+ 	rm -rf dist
+ 	find . -type f -name "*.DS_Store" -ls -delete
+ 	find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf
+ 	find . | grep -E ".pytest_cache" | xargs rm -rf
+ 	find . | grep -E ".ipynb_checkpoints" | xargs rm -rf
+ 	rm -f .coverage
+
+ clean-logs: ## Clean logs
+ 	rm -rf logs/**
+
+ format: ## Run pre-commit hooks
+ 	pre-commit run -a
+
+ sync: ## Merge changes from main branch to your current branch
+ 	git pull
+ 	git pull origin main
+
+ test: ## Run not slow tests
+ 	pytest -k "not slow"
+
+ test-full: ## Run all tests
+ 	pytest
+
+ train: ## Train the model
+ 	python diff_ttsg/train.py run_name=dev
app.py ADDED
@@ -0,0 +1,253 @@
+ import argparse
+ import datetime as dt
+ import warnings
+ from pathlib import Path
+
+ import ffmpeg
+ import gradio as gr
+ import IPython.display as ipd
+ import joblib as jl
+ import numpy as np
+ import soundfile as sf
+ import torch
+ from tqdm.auto import tqdm
+
+ from diff_ttsg.hifigan.config import v1
+ from diff_ttsg.hifigan.denoiser import Denoiser
+ from diff_ttsg.hifigan.env import AttrDict
+ from diff_ttsg.hifigan.models import Generator as HiFiGAN
+ from diff_ttsg.models.diff_ttsg import Diff_TTSG
+ from diff_ttsg.text import cmudict, sequence_to_text, text_to_sequence
+ from diff_ttsg.text.symbols import symbols
+ from diff_ttsg.utils.model import denormalize
+ from diff_ttsg.utils.utils import intersperse, plot_tensor
+ from pymo.preprocessing import MocapParameterizer
+ from pymo.viz_tools import render_mp4
+ from pymo.writers import BVHWriter
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ DIFF_TTSG_CHECKPOINT = "diff_ttsg_checkpoint.ckpt"
+ HIFIGAN_CHECKPOINT = "g_02500000"
+ MOTION_PIPELINE = "diff_ttsg/resources/data_pipe.expmap_86.1328125fps.sav"
+ CMU_DICT_PATH = "diff_ttsg/resources/cmu_dictionary"
+
+ OUTPUT_FOLDER = "synth_output"
+
+ # Model loading tools
+ def load_model(checkpoint_path):
+     model = Diff_TTSG.load_from_checkpoint(checkpoint_path, map_location=device)
+     model.eval()
+     return model
+
+ # Vocoder loading tools
+ def load_vocoder(checkpoint_path):
+     h = AttrDict(v1)
+     hifigan = HiFiGAN(h).to(device)
+     hifigan.load_state_dict(torch.load(checkpoint_path, map_location=device)['generator'])
+     _ = hifigan.eval()
+     hifigan.remove_weight_norm()
+     return hifigan
+
+ # Setup text preprocessing
+ cmu = cmudict.CMUDict(CMU_DICT_PATH)
+ def process_text(text: str):
+     x = torch.LongTensor(intersperse(text_to_sequence(text, dictionary=cmu), len(symbols))).to(device)[None]
+     x_lengths = torch.LongTensor([x.shape[-1]]).to(device)  # use the shared device so CPU-only runs also work
+     x_phones = sequence_to_text(x.squeeze(0).tolist())
+     return {
+         'x_orig': text,
+         'x': x,
+         'x_lengths': x_lengths,
+         'x_phones': x_phones
+     }
+
+ # Setup motion visualisation
+ motion_pipeline = jl.load(MOTION_PIPELINE)
+ bvh_writer = BVHWriter()
+ mocap_params = MocapParameterizer("position")
+
+
+
+ ## Load models
+
+ model = load_model(DIFF_TTSG_CHECKPOINT)
+ vocoder = load_vocoder(HIFIGAN_CHECKPOINT)
+ denoiser = Denoiser(vocoder, mode='zeros')
+
+
+ # Synthesis functions
+
+ @torch.inference_mode()
+ def synthesise(text, mel_timestep, motion_timestep, length_scale, mel_temp, motion_temp):
+
+     ## Number of timesteps to run the reverse denoising process
+     n_timesteps = {
+         'mel': mel_timestep,
+         'motion': motion_timestep,
+     }
+
+     ## Sampling temperature
+     temperature = {
+         'mel': mel_temp,
+         'motion': motion_temp
+     }
+     text_processed = process_text(text)
+     t = dt.datetime.now()
+     output = model.synthesise(
+         text_processed['x'],
+         text_processed['x_lengths'],
+         n_timesteps=n_timesteps,
+         temperature=temperature,
+         stoc=False,
+         spk=None,
+         length_scale=length_scale
+     )
+
+     t = (dt.datetime.now() - t).total_seconds()
+     print(f'RTF: {t * 22050 / (output["mel"].shape[-1] * 256)}')
+
+     output.update(text_processed)  # merge everything to one dict
+     return output
+
+ @torch.inference_mode()
+ def to_waveform(mel, vocoder):
+     audio = vocoder(mel).clamp(-1, 1)
+     audio = denoiser(audio.squeeze(0)).cpu().squeeze()
+     return audio
+
+
+ def to_bvh(motion):
+     with warnings.catch_warnings():
+         warnings.simplefilter("ignore")
+         return motion_pipeline.inverse_transform([motion.cpu().squeeze(0).T])
+
+
+ def save_to_folder(filename: str, output: dict, folder: str):
+     folder = Path(folder)
+     folder.mkdir(exist_ok=True, parents=True)
+     np.save(folder / f'{filename}', output['mel'].cpu().numpy())
+     sf.write(folder / f'{filename}.wav', output['waveform'], 22050, 'PCM_24')
+     with open(folder / f'{filename}.bvh', 'w') as f:
+         bvh_writer.write(output['bvh'], f)
+
+
+ def to_stick_video(filename, bvh, folder):
+     folder = Path(folder)
+     folder.mkdir(exist_ok=True, parents=True)
+
+     with warnings.catch_warnings():
+         warnings.simplefilter("ignore")
+         X_pos = mocap_params.fit_transform([bvh])
+     print(f"rendering {filename} ...")
+     render_mp4(X_pos[0], folder / f'{filename}.mp4', axis_scale=200)
+
+
+ def combine_audio_video(filename: str, folder: str):
+     print("Combining audio and video")
+     folder = Path(folder)
+     folder.mkdir(exist_ok=True, parents=True)
+
+     input_video = ffmpeg.input(str(folder / f'{filename}.mp4'))
+     input_audio = ffmpeg.input(str(folder / f'{filename}.wav'))
+     output_filename = folder / f'{filename}_audio.mp4'
+     ffmpeg.concat(input_video, input_audio, v=1, a=1).output(str(output_filename)).run(overwrite_output=True)
+     print(f"Final output with audio: {output_filename}")
+
+
+ def run(text, output, mel_timestep, motion_timestep, length_scale, mel_temp, motion_temp):
+     print("Running synthesis")
+     output = synthesise(text, mel_timestep, motion_timestep, length_scale, mel_temp, motion_temp)
+     output['waveform'] = to_waveform(output['mel'], vocoder)
+     output['bvh'] = to_bvh(output['motion'])[0]
+     save_to_folder('temp', output, OUTPUT_FOLDER)
+     return (
+         output,
+         output['x_phones'],
+         plot_tensor(output['mel'].squeeze().cpu().numpy()),
+         plot_tensor(output['motion'].squeeze().cpu().numpy()),
+         str(Path(OUTPUT_FOLDER) / 'temp.wav'),
+         gr.update(interactive=True)
+     )
+
+ def visualize_it(output):
+     to_stick_video('temp', output['bvh'], OUTPUT_FOLDER)
+     combine_audio_video('temp', OUTPUT_FOLDER)
+     return str(Path(OUTPUT_FOLDER) / 'temp_audio.mp4')
+
+
+ with gr.Blocks() as demo:
+
+     output = gr.State(value=None)
+
+     with gr.Row():
+         gr.Markdown("# Text Input")
+     with gr.Row():
+         text = gr.Textbox(label="Text Input")
+
+     with gr.Box():
+         with gr.Row():
+             gr.Markdown("### Hyper parameters")
+         with gr.Row():
+             mel_timestep = gr.Slider(label="Number of timesteps (mel)", minimum=0, maximum=1000, step=1, value=50, interactive=True)
+             motion_timestep = gr.Slider(label="Number of timesteps (motion)", minimum=0, maximum=1000, step=1, value=500, interactive=True)
+             length_scale = gr.Slider(label="Length scale (Speaking rate)", minimum=0.01, maximum=3.0, step=0.05, value=1.15, interactive=True)
+             mel_temp = gr.Slider(label="Sampling temperature (mel)", minimum=0.01, maximum=5.0, step=0.05, value=1.3, interactive=True)
+             motion_temp = gr.Slider(label="Sampling temperature (motion)", minimum=0.01, maximum=5.0, step=0.05, value=1.5, interactive=True)
+
+     synth_btn = gr.Button("Synthesise")
+
+     with gr.Box():
+         with gr.Row():
+             gr.Markdown("### Phonetised text")
+         with gr.Row():
+             phonetised_text = gr.Textbox(label="Phonetised text", interactive=False)
+
+     with gr.Box():
+         with gr.Row():
+             mel_spectrogram = gr.Image(interactive=False, label="mel spectrogram")
+             motion_representation = gr.Image(interactive=False, label="Motion representation")
+
+     with gr.Row():
+         audio = gr.Audio(interactive=False, label="Audio")
+
+     with gr.Box():
+         with gr.Row():
+             gr.Markdown("### Generate stick figure visualisation")
+         with gr.Row():
+             gr.Markdown("(This will take a while)")
+         with gr.Row():
+             visualize = gr.Button("Visualize", interactive=False)
+
+     with gr.Row():
+         video = gr.Video(label="Video", interactive=False)
+
+     synth_btn.click(
+         fn=run,
+         inputs=[
+             text,
+             output,
+             mel_timestep,
+             motion_timestep,
+             length_scale,
+             mel_temp,
+             motion_temp
+         ],
+         outputs=[
+             output,
+             phonetised_text,
+             mel_spectrogram,
+             motion_representation,
+             audio,
+             # video,
+             visualize
+         ], api_name="diff_ttsg")
+
+     visualize.click(
+         fn=visualize_it,
+         inputs=[output],
+         outputs=[video],
+     )
+
+ demo.queue(1)
+ demo.launch()
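The Gradio demo above wires `run` and `visualize_it` to UI events; the same pipeline can also be driven directly from Python. A minimal sketch, reusing only the functions and constants defined in app.py (the example sentence is a placeholder, and the checkpoints named at the top of the file are assumed to be present):

```python
# Sketch: programmatic use of the app.py helpers, bypassing the Gradio UI.
# Assumes diff_ttsg_checkpoint.ckpt, g_02500000 and the motion pipeline file exist locally.
out = synthesise(
    "This is a placeholder sentence.",       # hypothetical input text
    mel_timestep=50, motion_timestep=500,    # same defaults as the sliders above
    length_scale=1.15, mel_temp=1.3, motion_temp=1.5,
)
out['waveform'] = to_waveform(out['mel'], vocoder)
out['bvh'] = to_bvh(out['motion'])[0]
save_to_folder('example', out, OUTPUT_FOLDER)  # writes example.npy, example.wav, example.bvh
```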
diff_ttsg/__init__.py ADDED
File without changes
diff_ttsg/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (152 Bytes).
 
diff_ttsg/data/__init__.py ADDED
File without changes
diff_ttsg/data/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (157 Bytes).
 
diff_ttsg/data/__pycache__/cormac_datamodule.cpython-310.pyc ADDED
Binary file (7.29 kB).
 
diff_ttsg/data/components/__init__.py ADDED
File without changes
diff_ttsg/data/cormac_datamodule.py ADDED
@@ -0,0 +1,214 @@
+ import random
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import numpy as np
+ import pandas as pd
+ import torch
+ import torch.nn.functional as F
+ import torchaudio as ta
+ from einops import pack
+ from lightning import LightningDataModule
+ from torch.utils.data.dataloader import DataLoader
+
+ from diff_ttsg.text import cmudict, text_to_sequence
+ from diff_ttsg.text.symbols import symbols
+ from diff_ttsg.utils.audio import mel_spectrogram
+ from diff_ttsg.utils.model import fix_len_compatibility, normalize
+ from diff_ttsg.utils.utils import intersperse, parse_filelist
+
+
+ class CormacDataModule(LightningDataModule):
+
+     def __init__(
+         self,
+         train_filelist_path,
+         valid_filelist_path,
+         batch_size,
+         num_workers,
+         pin_memory,
+         cmudict_path,
+         motion_folder,
+         add_blank,
+         n_fft,
+         n_feats,
+         sample_rate,
+         hop_length,
+         win_length,
+         f_min,
+         f_max,
+         data_statistics,
+         motion_pipeline_filename,
+         seed
+     ):
+         super().__init__()
+
+         # this line allows to access init params with 'self.hparams' attribute
+         # also ensures init params will be stored in ckpt
+         self.save_hyperparameters(logger=False)
+
+     def setup(self, stage: Optional[str] = None):
+         """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
+
+         This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be
+         careful not to execute things like random split twice!
+         """
+         # load and split datasets only if not loaded already
+
+         self.trainset = TextMelDataset(
+             self.hparams.train_filelist_path,
+             self.hparams.cmudict_path,
+             self.hparams.motion_folder,
+             self.hparams.add_blank,
+             self.hparams.n_fft,
+             self.hparams.n_feats,
+             self.hparams.sample_rate,
+             self.hparams.hop_length,
+             self.hparams.win_length,
+             self.hparams.f_min,
+             self.hparams.f_max,
+             self.hparams.data_statistics,
+             self.hparams.seed
+         )
+         self.validset = TextMelDataset(
+             self.hparams.valid_filelist_path,
+             self.hparams.cmudict_path,
+             self.hparams.motion_folder,
+             self.hparams.add_blank,
+             self.hparams.n_fft,
+             self.hparams.n_feats,
+             self.hparams.sample_rate,
+             self.hparams.hop_length,
+             self.hparams.win_length,
+             self.hparams.f_min,
+             self.hparams.f_max,
+             self.hparams.data_statistics,
+             self.hparams.seed
+         )
+
+
+     def train_dataloader(self):
+         return DataLoader(
+             dataset=self.trainset,
+             batch_size=self.hparams.batch_size,
+             num_workers=self.hparams.num_workers,
+             pin_memory=self.hparams.pin_memory,
+             shuffle=True,
+             collate_fn=TextMelBatchCollate()
+         )
+
+     def val_dataloader(self):
+         return DataLoader(
+             dataset=self.validset,
+             batch_size=self.hparams.batch_size,
+             num_workers=self.hparams.num_workers,
+             pin_memory=self.hparams.pin_memory,
+             shuffle=False,
+             collate_fn=TextMelBatchCollate()
+         )
+
+     def teardown(self, stage: Optional[str] = None):
+         """Clean up after fit or test."""
+         pass
+
+     def state_dict(self):
+         """Extra things to save to checkpoint."""
+         return {}
+
+     def load_state_dict(self, state_dict: Dict[str, Any]):
+         """Things to do when loading checkpoint."""
+         pass
+
+
+ class TextMelDataset(torch.utils.data.Dataset):
+     def __init__(self, filelist_path, cmudict_path, motion_folder, add_blank=True,
+                  n_fft=1024, n_mels=80, sample_rate=22050,
+                  hop_length=256, win_length=1024, f_min=0., f_max=8000, data_parameters=None, seed=None):
+         self.filepaths_and_text = parse_filelist(filelist_path)
+         self.motion_fileloc = Path(motion_folder)
+         self.cmudict = cmudict.CMUDict(cmudict_path)
+         self.add_blank = add_blank
+         self.n_fft = n_fft
+         self.n_mels = n_mels
+         self.sample_rate = sample_rate
+         self.hop_length = hop_length
+         self.win_length = win_length
+         self.f_min = f_min
+         self.f_max = f_max
+         if data_parameters is not None:
+             self.data_parameters = data_parameters
+         else:
+             self.data_parameters = { 'mel_mean': 0, 'mel_std': 1, 'motion_mean': 0, 'motion_std': 1 }
+         random.seed(seed)
+         random.shuffle(self.filepaths_and_text)
+
+     def get_pair(self, filepath_and_text):
+         filepath, text = filepath_and_text[0], filepath_and_text[1]
+         text = self.get_text(text, add_blank=self.add_blank)
+         mel = self.get_mel(filepath)
+         motion = self.get_motion(filepath, mel.shape[1])
+         return (text, mel, motion)
+
+     def get_motion(self, filename, mel_shape, ext=".expmap_86.1328125fps.pkl"):
+         file_loc = self.motion_fileloc / Path(Path(filename).name).with_suffix(ext)
+         motion = torch.from_numpy(pd.read_pickle(file_loc).to_numpy())
+         motion = F.interpolate(motion.T.unsqueeze(0), mel_shape).squeeze(0)
+         motion = normalize(motion, self.data_parameters['motion_mean'], self.data_parameters['motion_std'])
+         return motion
+
+     def get_mel(self, filepath):
+         audio, sr = ta.load(filepath)
+         assert sr == self.sample_rate
+         mel = mel_spectrogram(audio, self.n_fft, 80, self.sample_rate, self.hop_length,
+                               self.win_length, self.f_min, self.f_max, center=False).squeeze()
+         mel = normalize(mel, self.data_parameters['mel_mean'], self.data_parameters['mel_std'])
+         return mel
+
+     def get_text(self, text, add_blank=True):
+         text_norm = text_to_sequence(text, dictionary=self.cmudict)
+         if self.add_blank:
+             text_norm = intersperse(text_norm, len(symbols))  # add a blank token, whose id number is len(symbols)
+         text_norm = torch.IntTensor(text_norm)
+         return text_norm
+
+     def __getitem__(self, index):
+         text, mel, motion = self.get_pair(self.filepaths_and_text[index])
+         item = {'y': mel, 'x': text, 'y_motion': motion}
+         return item
+
+     def __len__(self):
+         return len(self.filepaths_and_text)
+
+     def sample_test_batch(self, size):
+         idx = np.random.choice(range(len(self)), size=size, replace=False)
+         test_batch = []
+         for index in idx:
+             test_batch.append(self.__getitem__(index))
+         return test_batch
+
+
+ class TextMelBatchCollate(object):
+     def __call__(self, batch):
+         B = len(batch)
+         y_max_length = max([item['y'].shape[-1] for item in batch])
+         y_max_length = fix_len_compatibility(y_max_length)
+         x_max_length = max([item['x'].shape[-1] for item in batch])
+         n_feats = batch[0]['y'].shape[-2]
+         n_motion = batch[0]['y_motion'].shape[-2]
+
+         y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32)
+         x = torch.zeros((B, x_max_length), dtype=torch.long)
+         y_motion = torch.zeros((B, n_motion, y_max_length), dtype=torch.float32)
+         y_lengths, x_lengths = [], []
+
+         for i, item in enumerate(batch):
+             y_, x_, y_motion_ = item['y'], item['x'], item['y_motion']
+             y_lengths.append(y_.shape[-1])
+             x_lengths.append(x_.shape[-1])
+             y[i, :, :y_.shape[-1]] = y_
+             x[i, :x_.shape[-1]] = x_
+             y_motion[i, :, :y_motion_.shape[-1]] = y_motion_
+
+         y_lengths = torch.LongTensor(y_lengths)
+         x_lengths = torch.LongTensor(x_lengths)
+         return {'x': x, 'x_lengths': x_lengths, 'y': y, 'y_lengths': y_lengths, 'y_motion': y_motion}
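For orientation, the collate class above pads every item in a batch to a shared length. A hypothetical instantiation might look like the sketch below; all paths and most values are placeholders, only the constructor arguments, batch keys and tensor layout come from the code in this file.

```python
# Sketch: building the datamodule and inspecting one collated batch (paths are hypothetical).
dm = CormacDataModule(
    train_filelist_path="data/train_filelist.txt",
    valid_filelist_path="data/valid_filelist.txt",
    batch_size=16, num_workers=4, pin_memory=True,
    cmudict_path="diff_ttsg/resources/cmu_dictionary",
    motion_folder="data/motion", add_blank=True,
    n_fft=1024, n_feats=80, sample_rate=22050,
    hop_length=256, win_length=1024, f_min=0.0, f_max=8000,
    data_statistics=None,                   # falls back to mean 0 / std 1 normalisation
    motion_pipeline_filename=None, seed=1234,
)
dm.setup()
batch = next(iter(dm.train_dataloader()))
# batch['x']: (B, max_text_len) token ids, batch['y']: (B, n_feats, max_frames) mels,
# batch['y_motion']: (B, n_motion, max_frames), plus 'x_lengths' and 'y_lengths'.
```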
diff_ttsg/data/mnist_datamodule.py ADDED
@@ -0,0 +1,130 @@
+ from typing import Any, Dict, Optional, Tuple
+
+ import torch
+ from lightning import LightningDataModule
+ from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
+ from torchvision.datasets import MNIST
+ from torchvision.transforms import transforms
+
+
+ class MNISTDataModule(LightningDataModule):
+     """Example of LightningDataModule for MNIST dataset.
+
+     A DataModule implements 6 key methods:
+         def prepare_data(self):
+             # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
+             # download data, pre-process, split, save to disk, etc...
+         def setup(self, stage):
+             # things to do on every process in DDP
+             # load data, set variables, etc...
+         def train_dataloader(self):
+             # return train dataloader
+         def val_dataloader(self):
+             # return validation dataloader
+         def test_dataloader(self):
+             # return test dataloader
+         def teardown(self):
+             # called on every process in DDP
+             # clean up after fit or test
+
+     This allows you to share a full dataset without explaining how to download,
+     split, transform and process the data.
+
+     Read the docs:
+         https://lightning.ai/docs/pytorch/latest/data/datamodule.html
+     """
+
+     def __init__(
+         self,
+         data_dir: str = "data/",
+         train_val_test_split: Tuple[int, int, int] = (55_000, 5_000, 10_000),
+         batch_size: int = 64,
+         num_workers: int = 0,
+         pin_memory: bool = False,
+     ):
+         super().__init__()
+
+         # this line allows to access init params with 'self.hparams' attribute
+         # also ensures init params will be stored in ckpt
+         self.save_hyperparameters(logger=False)
+
+         # data transformations
+         self.transforms = transforms.Compose(
+             [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+         )
+
+         self.data_train: Optional[Dataset] = None
+         self.data_val: Optional[Dataset] = None
+         self.data_test: Optional[Dataset] = None
+
+     @property
+     def num_classes(self):
+         return 10
+
+     def prepare_data(self):
+         """Download data if needed.
+
+         Do not use it to assign state (self.x = y).
+         """
+         MNIST(self.hparams.data_dir, train=True, download=True)
+         MNIST(self.hparams.data_dir, train=False, download=True)
+
+     def setup(self, stage: Optional[str] = None):
+         """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
+
+         This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be
+         careful not to execute things like random split twice!
+         """
+         # load and split datasets only if not loaded already
+         if not self.data_train and not self.data_val and not self.data_test:
+             trainset = MNIST(self.hparams.data_dir, train=True, transform=self.transforms)
+             testset = MNIST(self.hparams.data_dir, train=False, transform=self.transforms)
+             dataset = ConcatDataset(datasets=[trainset, testset])
+             self.data_train, self.data_val, self.data_test = random_split(
+                 dataset=dataset,
+                 lengths=self.hparams.train_val_test_split,
+                 generator=torch.Generator().manual_seed(42),
+             )
+
+     def train_dataloader(self):
+         return DataLoader(
+             dataset=self.data_train,
+             batch_size=self.hparams.batch_size,
+             num_workers=self.hparams.num_workers,
+             pin_memory=self.hparams.pin_memory,
+             shuffle=True,
+         )
+
+     def val_dataloader(self):
+         return DataLoader(
+             dataset=self.data_val,
+             batch_size=self.hparams.batch_size,
+             num_workers=self.hparams.num_workers,
+             pin_memory=self.hparams.pin_memory,
+             shuffle=False,
+         )
+
+     def test_dataloader(self):
+         return DataLoader(
+             dataset=self.data_test,
+             batch_size=self.hparams.batch_size,
+             num_workers=self.hparams.num_workers,
+             pin_memory=self.hparams.pin_memory,
+             shuffle=False,
+         )
+
+     def teardown(self, stage: Optional[str] = None):
+         """Clean up after fit or test."""
+         pass
+
+     def state_dict(self):
+         """Extra things to save to checkpoint."""
+         return {}
+
+     def load_state_dict(self, state_dict: Dict[str, Any]):
+         """Things to do when loading checkpoint."""
+         pass
+
+
+ if __name__ == "__main__":
+     _ = MNISTDataModule()
diff_ttsg/eval.py ADDED
@@ -0,0 +1,93 @@
+ from typing import List, Tuple
+
+ import hydra
+ import pyrootutils
+ from lightning import LightningDataModule, LightningModule, Trainer
+ from lightning.pytorch.loggers import Logger
+ from omegaconf import DictConfig
+
+ pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+ # ------------------------------------------------------------------------------------ #
+ # the setup_root above is equivalent to:
+ # - adding project root dir to PYTHONPATH
+ #       (so you don't need to force user to install project as a package)
+ #       (necessary before importing any local modules e.g. `from src import utils`)
+ # - setting up PROJECT_ROOT environment variable
+ #       (which is used as a base for paths in "configs/paths/default.yaml")
+ #       (this way all filepaths are the same no matter where you run the code)
+ # - loading environment variables from ".env" in root dir
+ #
+ # you can remove it if you:
+ # 1. either install project as a package or move entry files to project root dir
+ # 2. set `root_dir` to "." in "configs/paths/default.yaml"
+ #
+ # more info: https://github.com/ashleve/pyrootutils
+ # ------------------------------------------------------------------------------------ #
+
+ from diff_ttsg import utils
+
+ log = utils.get_pylogger(__name__)
+
+
+ @utils.task_wrapper
+ def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
+     """Evaluates given checkpoint on a datamodule testset.
+
+     This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
+     failure. Useful for multiruns, saving info about the crash, etc.
+
+     Args:
+         cfg (DictConfig): Configuration composed by Hydra.
+
+     Returns:
+         Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
+     """
+
+     assert cfg.ckpt_path
+
+     log.info(f"Instantiating datamodule <{cfg.data._target_}>")
+     datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
+
+     log.info(f"Instantiating model <{cfg.model._target_}>")
+     model: LightningModule = hydra.utils.instantiate(cfg.model)
+
+     log.info("Instantiating loggers...")
+     logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"))
+
+     log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
+     trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)
+
+     object_dict = {
+         "cfg": cfg,
+         "datamodule": datamodule,
+         "model": model,
+         "logger": logger,
+         "trainer": trainer,
+     }
+
+     if logger:
+         log.info("Logging hyperparameters!")
+         utils.log_hyperparameters(object_dict)
+
+     log.info("Starting testing!")
+     trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)
+
+     # for predictions use trainer.predict(...)
+     # predictions = trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=cfg.ckpt_path)
+
+     metric_dict = trainer.callback_metrics
+
+     return metric_dict, object_dict
+
+
+ @hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml")
+ def main(cfg: DictConfig) -> None:
+     # apply extra utilities
+     # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
+     utils.extras(cfg)
+
+     evaluate(cfg)
+
+
+ if __name__ == "__main__":
+     main()
diff_ttsg/hifigan/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2020 Jungil Kong
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
diff_ttsg/hifigan/README.md ADDED
@@ -0,0 +1,105 @@
+ # HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis
+
+ ### Jungil Kong, Jaehyeon Kim, Jaekyoung Bae
+
+ In our [paper](https://arxiv.org/abs/2010.05646),
+ we proposed HiFi-GAN: a GAN-based model capable of generating high fidelity speech efficiently.<br/>
+ We provide our implementation and pretrained models as open source in this repository.
+
+ **Abstract :**
+ Several recent work on speech synthesis have employed generative adversarial networks (GANs) to produce raw waveforms.
+ Although such methods improve the sampling efficiency and memory usage,
+ their sample quality has not yet reached that of autoregressive and flow-based generative models.
+ In this work, we propose HiFi-GAN, which achieves both efficient and high-fidelity speech synthesis.
+ As speech audio consists of sinusoidal signals with various periods,
+ we demonstrate that modeling periodic patterns of an audio is crucial for enhancing sample quality.
+ A subjective human evaluation (mean opinion score, MOS) of a single speaker dataset indicates that our proposed method
+ demonstrates similarity to human quality while generating 22.05 kHz high-fidelity audio 167.9 times faster than
+ real-time on a single V100 GPU. We further show the generality of HiFi-GAN to the mel-spectrogram inversion of unseen
+ speakers and end-to-end speech synthesis. Finally, a small footprint version of HiFi-GAN generates samples 13.4 times
+ faster than real-time on CPU with comparable quality to an autoregressive counterpart.
+
+ Visit our [demo website](https://jik876.github.io/hifi-gan-demo/) for audio samples.
+
+
+ ## Pre-requisites
+ 1. Python >= 3.6
+ 2. Clone this repository.
+ 3. Install python requirements. Please refer [requirements.txt](requirements.txt)
+ 4. Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/).
+    And move all wav files to `LJSpeech-1.1/wavs`
+
+
+ ## Training
+ ```
+ python train.py --config config_v1.json
+ ```
+ To train V2 or V3 Generator, replace `config_v1.json` with `config_v2.json` or `config_v3.json`.<br>
+ Checkpoints and copy of the configuration file are saved in `cp_hifigan` directory by default.<br>
+ You can change the path by adding `--checkpoint_path` option.
+
+ Validation loss during training with V1 generator.<br>
+ ![validation loss](./validation_loss.png)
+
+ ## Pretrained Model
+ You can also use pretrained models we provide.<br/>
+ [Download pretrained models](https://drive.google.com/drive/folders/1-eEYTB5Av9jNql0WGBlRoi-WH2J7bp5Y?usp=sharing)<br/>
+ Details of each folder are as follows:
+
+ |Folder Name|Generator|Dataset|Fine-Tuned|
+ |------|---|---|---|
+ |LJ_V1|V1|LJSpeech|No|
+ |LJ_V2|V2|LJSpeech|No|
+ |LJ_V3|V3|LJSpeech|No|
+ |LJ_FT_T2_V1|V1|LJSpeech|Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2))|
+ |LJ_FT_T2_V2|V2|LJSpeech|Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2))|
+ |LJ_FT_T2_V3|V3|LJSpeech|Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2))|
+ |VCTK_V1|V1|VCTK|No|
+ |VCTK_V2|V2|VCTK|No|
+ |VCTK_V3|V3|VCTK|No|
+ |UNIVERSAL_V1|V1|Universal|No|
+
+ We provide the universal model with discriminator weights that can be used as a base for transfer learning to other datasets.
+
+ ## Fine-Tuning
+ 1. Generate mel-spectrograms in numpy format using [Tacotron2](https://github.com/NVIDIA/tacotron2) with teacher-forcing.<br/>
+    The file name of the generated mel-spectrogram should match the audio file and the extension should be `.npy`.<br/>
+    Example:
+    ```
+    Audio File : LJ001-0001.wav
+    Mel-Spectrogram File : LJ001-0001.npy
+    ```
+ 2. Create `ft_dataset` folder and copy the generated mel-spectrogram files into it.<br/>
+ 3. Run the following command.
+    ```
+    python train.py --fine_tuning True --config config_v1.json
+    ```
+    For other command line options, please refer to the training section.
+
+
+ ## Inference from wav file
+ 1. Make `test_files` directory and copy wav files into the directory.
+ 2. Run the following command.
+    ```
+    python inference.py --checkpoint_file [generator checkpoint file path]
+    ```
+ Generated wav files are saved in `generated_files` by default.<br>
+ You can change the path by adding `--output_dir` option.
+
+
+ ## Inference for end-to-end speech synthesis
+ 1. Make `test_mel_files` directory and copy generated mel-spectrogram files into the directory.<br>
+    You can generate mel-spectrograms using [Tacotron2](https://github.com/NVIDIA/tacotron2),
+    [Glow-TTS](https://github.com/jaywalnut310/glow-tts) and so forth.
+ 2. Run the following command.
+    ```
+    python inference_e2e.py --checkpoint_file [generator checkpoint file path]
+    ```
+ Generated wav files are saved in `generated_files_from_mel` by default.<br>
+ You can change the path by adding `--output_dir` option.
+
+
+ ## Acknowledgements
+ We referred to [WaveGlow](https://github.com/NVIDIA/waveglow), [MelGAN](https://github.com/descriptinc/melgan-neurips)
+ and [Tacotron2](https://github.com/NVIDIA/tacotron2) to implement this.
+
File without changes
diff_ttsg/hifigan/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (160 Bytes).
 
diff_ttsg/hifigan/__pycache__/config.cpython-310.pyc ADDED
Binary file (1.02 kB).
 
diff_ttsg/hifigan/__pycache__/denoiser.cpython-310.pyc ADDED
Binary file (2.56 kB).
 
diff_ttsg/hifigan/__pycache__/env.cpython-310.pyc ADDED
Binary file (883 Bytes).
 
diff_ttsg/hifigan/__pycache__/models.cpython-310.pyc ADDED
Binary file (8.73 kB).
 
diff_ttsg/hifigan/__pycache__/xutils.cpython-310.pyc ADDED
Binary file (2.1 kB).
 
diff_ttsg/hifigan/config.py ADDED
@@ -0,0 +1,38 @@
+ v1 = {
+     "resblock": "1",
+     "num_gpus": 0,
+     "batch_size": 16,
+     "learning_rate": 0.0004,
+     "adam_b1": 0.8,
+     "adam_b2": 0.99,
+     "lr_decay": 0.999,
+     "seed": 1234,
+
+     "upsample_rates": [8,8,2,2],
+     "upsample_kernel_sizes": [16,16,4,4],
+     "upsample_initial_channel": 512,
+     "resblock_kernel_sizes": [3,7,11],
+     "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
+     "resblock_initial_channel": 256,
+
+     "segment_size": 8192,
+     "num_mels": 80,
+     "num_freq": 1025,
+     "n_fft": 1024,
+     "hop_size": 256,
+     "win_size": 1024,
+
+     "sampling_rate": 22050,
+
+     "fmin": 0,
+     "fmax": 8000,
+     "fmax_loss": None,
+
+     "num_workers": 4,
+
+     "dist_config": {
+         "dist_backend": "nccl",
+         "dist_url": "tcp://localhost:54321",
+         "world_size": 1
+     }
+ }
diff_ttsg/hifigan/denoiser.py ADDED
@@ -0,0 +1,64 @@
+ ### Code modified from Rafael Valle's implementation https://github.com/NVIDIA/waveglow/blob/5bc2a53e20b3b533362f974cfa1ea0267ae1c2b1/denoiser.py
+
+ """Waveglow style denoiser can be used to remove the artifacts from the HiFiGAN generated audio."""
+ import torch
+
+
+ class Denoiser(torch.nn.Module):
+     """Removes model bias from audio produced with waveglow"""
+
+     def __init__(self, vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros"):
+         super().__init__()
+         self.filter_length = filter_length
+         self.hop_length = int(filter_length / n_overlap)
+         self.win_length = win_length
+
+         dtype, device = next(vocoder.parameters()).dtype, next(vocoder.parameters()).device
+         self.device = device
+         if mode == "zeros":
+             mel_input = torch.zeros((1, 80, 88), dtype=dtype, device=device)
+         elif mode == "normal":
+             mel_input = torch.randn((1, 80, 88), dtype=dtype, device=device)
+         else:
+             raise Exception(f"Mode {mode} is not supported")
+
+         def stft_fn(audio, n_fft, hop_length, win_length, window):
+             spec = torch.stft(
+                 audio,
+                 n_fft=n_fft,
+                 hop_length=hop_length,
+                 win_length=win_length,
+                 window=window,
+                 return_complex=True,
+             )
+             spec = torch.view_as_real(spec)
+             return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2(spec[..., -1], spec[..., 0])
+
+         self.stft = lambda x: stft_fn(
+             audio=x,
+             n_fft=self.filter_length,
+             hop_length=self.hop_length,
+             win_length=self.win_length,
+             window=torch.hann_window(self.win_length, device=device)
+         )
+         self.istft = lambda x, y: torch.istft(
+             torch.complex(x * torch.cos(y), x * torch.sin(y)),
+             n_fft=self.filter_length,
+             hop_length=self.hop_length,
+             win_length=self.win_length,
+             window=torch.hann_window(self.win_length, device=device),
+         )
+
+         with torch.no_grad():
+             bias_audio = vocoder(mel_input).float().squeeze(0)
+             bias_spec, _ = self.stft(bias_audio)
+
+         self.register_buffer("bias_spec", bias_spec[:, :, 0][:, :, None])
+
+     @torch.inference_mode()
+     def forward(self, audio, strength=0.0005):
+         audio_spec, audio_angles = self.stft(audio)
+         audio_spec_denoised = audio_spec - self.bias_spec.to(audio.device) * strength
+         audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0)
+         audio_denoised = self.istft(audio_spec_denoised, audio_angles)
+         return audio_denoised
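app.py above uses this class by wrapping the loaded HiFi-GAN generator; a condensed sketch of that pattern (the mel tensor is assumed to come from the acoustic model):

```python
# Sketch: denoising HiFi-GAN output, mirroring the usage in app.py.
vocoder = load_vocoder("g_02500000")        # generator with weight norm removed (see app.py)
denoiser = Denoiser(vocoder, mode='zeros')  # estimates the vocoder's bias spectrum once
audio = vocoder(mel).clamp(-1, 1)           # mel: (1, 80, T) tensor from the model (assumed)
audio = denoiser(audio.squeeze(0), strength=0.0005).cpu().squeeze()
```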
diff_ttsg/hifigan/env.py ADDED
@@ -0,0 +1,17 @@
+ """ from https://github.com/jik876/hifi-gan """
+
+ import os
+ import shutil
+
+
+ class AttrDict(dict):
+     def __init__(self, *args, **kwargs):
+         super(AttrDict, self).__init__(*args, **kwargs)
+         self.__dict__ = self
+
+
+ def build_env(config, config_name, path):
+     t_path = os.path.join(path, config_name)
+     if config != t_path:
+         os.makedirs(path, exist_ok=True)
+         shutil.copyfile(config, os.path.join(path, config_name))
diff_ttsg/hifigan/meldataset.py ADDED
@@ -0,0 +1,171 @@
+ """ from https://github.com/jik876/hifi-gan """
+
+ import math
+ import os
+ import random
+
+ import numpy as np
+ import torch
+ import torch.utils.data
+ from librosa.filters import mel as librosa_mel_fn
+ from librosa.util import normalize
+ from scipy.io.wavfile import read
+
+ MAX_WAV_VALUE = 32768.0
+
+
+ def load_wav(full_path):
+     sampling_rate, data = read(full_path)
+     return data, sampling_rate
+
+
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
+     return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+
+ def dynamic_range_decompression(x, C=1):
+     return np.exp(x) / C
+
+
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+     return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+ def dynamic_range_decompression_torch(x, C=1):
+     return torch.exp(x) / C
+
+
+ def spectral_normalize_torch(magnitudes):
+     output = dynamic_range_compression_torch(magnitudes)
+     return output
+
+
+ def spectral_de_normalize_torch(magnitudes):
+     output = dynamic_range_decompression_torch(magnitudes)
+     return output
+
+
+ mel_basis = {}
+ hann_window = {}
+
+
+ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
+     if torch.min(y) < -1.:
+         print('min value is ', torch.min(y))
+     if torch.max(y) > 1.:
+         print('max value is ', torch.max(y))
+
+     global mel_basis, hann_window
+     if fmax not in mel_basis:
+         mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
+         mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+         hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
+
+     y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
+     y = y.squeeze(1)
+
+     spec = torch.view_as_real(torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
+                               center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True))
+
+     spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
+
+     spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
+     spec = spectral_normalize_torch(spec)
+
+     return spec
+
+
+ def get_dataset_filelist(a):
+     with open(a.input_training_file, 'r', encoding='utf-8') as fi:
+         training_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
+                           for x in fi.read().split('\n') if len(x) > 0]
+
+     with open(a.input_validation_file, 'r', encoding='utf-8') as fi:
+         validation_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
+                             for x in fi.read().split('\n') if len(x) > 0]
+     return training_files, validation_files
+
+
+ class MelDataset(torch.utils.data.Dataset):
+     def __init__(self, training_files, segment_size, n_fft, num_mels,
+                  hop_size, win_size, sampling_rate, fmin, fmax, split=True, shuffle=True, n_cache_reuse=1,
+                  device=None, fmax_loss=None, fine_tuning=False, base_mels_path=None):
+         self.audio_files = training_files
+         random.seed(1234)
+         if shuffle:
+             random.shuffle(self.audio_files)
+         self.segment_size = segment_size
+         self.sampling_rate = sampling_rate
+         self.split = split
+         self.n_fft = n_fft
+         self.num_mels = num_mels
+         self.hop_size = hop_size
+         self.win_size = win_size
+         self.fmin = fmin
+         self.fmax = fmax
+         self.fmax_loss = fmax_loss
+         self.cached_wav = None
+         self.n_cache_reuse = n_cache_reuse
+         self._cache_ref_count = 0
+         self.device = device
+         self.fine_tuning = fine_tuning
+         self.base_mels_path = base_mels_path
+
+     def __getitem__(self, index):
+         filename = self.audio_files[index]
+         if self._cache_ref_count == 0:
+             audio, sampling_rate = load_wav(filename)
+             audio = audio / MAX_WAV_VALUE
+             if not self.fine_tuning:
+                 audio = normalize(audio) * 0.95
+             self.cached_wav = audio
+             if sampling_rate != self.sampling_rate:
+                 raise ValueError("{} SR doesn't match target {} SR".format(
+                     sampling_rate, self.sampling_rate))
+             self._cache_ref_count = self.n_cache_reuse
+         else:
+             audio = self.cached_wav
+             self._cache_ref_count -= 1
+
+         audio = torch.FloatTensor(audio)
+         audio = audio.unsqueeze(0)
+
+         if not self.fine_tuning:
+             if self.split:
+                 if audio.size(1) >= self.segment_size:
+                     max_audio_start = audio.size(1) - self.segment_size
+                     audio_start = random.randint(0, max_audio_start)
+                     audio = audio[:, audio_start:audio_start+self.segment_size]
+                 else:
+                     audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
+
+             mel = mel_spectrogram(audio, self.n_fft, self.num_mels,
+                                   self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax,
+                                   center=False)
+         else:
+             mel = np.load(
+                 os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + '.npy'))
+             mel = torch.from_numpy(mel)
+
+             if len(mel.shape) < 3:
+                 mel = mel.unsqueeze(0)
+
+             if self.split:
+                 frames_per_seg = math.ceil(self.segment_size / self.hop_size)
+
+                 if audio.size(1) >= self.segment_size:
+                     mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
+                     mel = mel[:, :, mel_start:mel_start + frames_per_seg]
+                     audio = audio[:, mel_start * self.hop_size:(mel_start + frames_per_seg) * self.hop_size]
+                 else:
+                     mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), 'constant')
+                     audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
+
+         mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels,
+                                    self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss,
+                                    center=False)
+
+         return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
+
+     def __len__(self):
+         return len(self.audio_files)
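With the v1 configuration used elsewhere in this commit, mel extraction with the function above looks roughly like this sketch (the wav path is a placeholder; a 22050 Hz mono file is assumed):

```python
import torch
from diff_ttsg.hifigan.meldataset import MAX_WAV_VALUE, load_wav, mel_spectrogram

audio, sr = load_wav("example.wav")  # placeholder path
audio = torch.FloatTensor(audio / MAX_WAV_VALUE).unsqueeze(0)
mel = mel_spectrogram(audio, n_fft=1024, num_mels=80, sampling_rate=22050,
                      hop_size=256, win_size=1024, fmin=0, fmax=8000, center=False)
# mel: (1, 80, n_frames) log-compressed magnitudes.
```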
diff_ttsg/hifigan/models.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ from https://github.com/jik876/hifi-gan """
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
7
+ from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
8
+
9
+ from .xutils import get_padding, init_weights
10
+
11
+ LRELU_SLOPE = 0.1
12
+
13
+
14
+ class ResBlock1(torch.nn.Module):
15
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
16
+ super(ResBlock1, self).__init__()
17
+ self.h = h
18
+ self.convs1 = nn.ModuleList([
19
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
20
+ padding=get_padding(kernel_size, dilation[0]))),
21
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
22
+ padding=get_padding(kernel_size, dilation[1]))),
23
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
24
+ padding=get_padding(kernel_size, dilation[2])))
25
+ ])
26
+ self.convs1.apply(init_weights)
27
+
28
+ self.convs2 = nn.ModuleList([
29
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
30
+ padding=get_padding(kernel_size, 1))),
31
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
32
+ padding=get_padding(kernel_size, 1))),
33
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
34
+ padding=get_padding(kernel_size, 1)))
35
+ ])
36
+ self.convs2.apply(init_weights)
37
+
38
+ def forward(self, x):
39
+ for c1, c2 in zip(self.convs1, self.convs2):
40
+ xt = F.leaky_relu(x, LRELU_SLOPE)
41
+ xt = c1(xt)
42
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
43
+ xt = c2(xt)
44
+ x = xt + x
45
+ return x
46
+
47
+ def remove_weight_norm(self):
48
+ for l in self.convs1:
49
+ remove_weight_norm(l)
50
+ for l in self.convs2:
51
+ remove_weight_norm(l)
52
+
53
+
54
+ class ResBlock2(torch.nn.Module):
55
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
56
+ super(ResBlock2, self).__init__()
57
+ self.h = h
58
+ self.convs = nn.ModuleList([
59
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
60
+ padding=get_padding(kernel_size, dilation[0]))),
61
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
62
+ padding=get_padding(kernel_size, dilation[1])))
63
+ ])
64
+ self.convs.apply(init_weights)
65
+
66
+ def forward(self, x):
67
+ for c in self.convs:
68
+ xt = F.leaky_relu(x, LRELU_SLOPE)
69
+ xt = c(xt)
70
+ x = xt + x
71
+ return x
72
+
73
+ def remove_weight_norm(self):
74
+ for l in self.convs:
75
+ remove_weight_norm(l)
76
+
77
+
78
+ class Generator(torch.nn.Module):
79
+ def __init__(self, h):
80
+ super(Generator, self).__init__()
81
+ self.h = h
82
+ self.num_kernels = len(h.resblock_kernel_sizes)
83
+ self.num_upsamples = len(h.upsample_rates)
84
+ self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
85
+ resblock = ResBlock1 if h.resblock == '1' else ResBlock2
86
+
87
+ self.ups = nn.ModuleList()
88
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
89
+ self.ups.append(weight_norm(
90
+ ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
91
+ k, u, padding=(k-u)//2)))
92
+
93
+ self.resblocks = nn.ModuleList()
94
+ for i in range(len(self.ups)):
95
+ ch = h.upsample_initial_channel//(2**(i+1))
96
+ for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
97
+ self.resblocks.append(resblock(h, ch, k, d))
98
+
99
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
100
+ self.ups.apply(init_weights)
101
+ self.conv_post.apply(init_weights)
102
+
103
+ def forward(self, x):
104
+ x = self.conv_pre(x)
105
+ for i in range(self.num_upsamples):
106
+ x = F.leaky_relu(x, LRELU_SLOPE)
107
+ x = self.ups[i](x)
108
+ xs = None
109
+ for j in range(self.num_kernels):
110
+ if xs is None:
111
+ xs = self.resblocks[i*self.num_kernels+j](x)
112
+ else:
113
+ xs += self.resblocks[i*self.num_kernels+j](x)
114
+ x = xs / self.num_kernels
115
+ x = F.leaky_relu(x)
116
+ x = self.conv_post(x)
117
+ x = torch.tanh(x)
118
+
119
+ return x
120
+
121
+ def remove_weight_norm(self):
122
+ print('Removing weight norm...')
123
+ for l in self.ups:
124
+ remove_weight_norm(l)
125
+ for l in self.resblocks:
126
+ l.remove_weight_norm()
127
+ remove_weight_norm(self.conv_pre)
128
+ remove_weight_norm(self.conv_post)
129
+
130
+
131
+ class DiscriminatorP(torch.nn.Module):
132
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
133
+ super(DiscriminatorP, self).__init__()
134
+ self.period = period
135
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
136
+ self.convs = nn.ModuleList([
137
+ norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
138
+ norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
139
+ norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
140
+ norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
141
+ norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
142
+ ])
143
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
144
+
145
+ def forward(self, x):
146
+ fmap = []
147
+
148
+ # 1d to 2d
149
+ b, c, t = x.shape
150
+ if t % self.period != 0: # pad first
151
+ n_pad = self.period - (t % self.period)
152
+ x = F.pad(x, (0, n_pad), "reflect")
153
+ t = t + n_pad
154
+ x = x.view(b, c, t // self.period, self.period)
155
+
156
+ for l in self.convs:
157
+ x = l(x)
158
+ x = F.leaky_relu(x, LRELU_SLOPE)
159
+ fmap.append(x)
160
+ x = self.conv_post(x)
161
+ fmap.append(x)
162
+ x = torch.flatten(x, 1, -1)
163
+
164
+ return x, fmap
165
+
166
+
167
+ class MultiPeriodDiscriminator(torch.nn.Module):
168
+ def __init__(self):
169
+ super(MultiPeriodDiscriminator, self).__init__()
170
+ self.discriminators = nn.ModuleList([
171
+ DiscriminatorP(2),
172
+ DiscriminatorP(3),
173
+ DiscriminatorP(5),
174
+ DiscriminatorP(7),
175
+ DiscriminatorP(11),
176
+ ])
177
+
178
+ def forward(self, y, y_hat):
179
+ y_d_rs = []
180
+ y_d_gs = []
181
+ fmap_rs = []
182
+ fmap_gs = []
183
+ for i, d in enumerate(self.discriminators):
184
+ y_d_r, fmap_r = d(y)
185
+ y_d_g, fmap_g = d(y_hat)
186
+ y_d_rs.append(y_d_r)
187
+ fmap_rs.append(fmap_r)
188
+ y_d_gs.append(y_d_g)
189
+ fmap_gs.append(fmap_g)
190
+
191
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
192
+
193
+
194
+ class DiscriminatorS(torch.nn.Module):
195
+ def __init__(self, use_spectral_norm=False):
196
+ super(DiscriminatorS, self).__init__()
197
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
198
+ self.convs = nn.ModuleList([
199
+ norm_f(Conv1d(1, 128, 15, 1, padding=7)),
200
+ norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
201
+ norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
202
+ norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
203
+ norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
204
+ norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
205
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
206
+ ])
207
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
208
+
209
+ def forward(self, x):
210
+ fmap = []
211
+ for l in self.convs:
212
+ x = l(x)
213
+ x = F.leaky_relu(x, LRELU_SLOPE)
214
+ fmap.append(x)
215
+ x = self.conv_post(x)
216
+ fmap.append(x)
217
+ x = torch.flatten(x, 1, -1)
218
+
219
+ return x, fmap
220
+
221
+
222
+ class MultiScaleDiscriminator(torch.nn.Module):
223
+ def __init__(self):
224
+ super(MultiScaleDiscriminator, self).__init__()
225
+ self.discriminators = nn.ModuleList([
226
+ DiscriminatorS(use_spectral_norm=True),
227
+ DiscriminatorS(),
228
+ DiscriminatorS(),
229
+ ])
230
+ self.meanpools = nn.ModuleList([
231
+ AvgPool1d(4, 2, padding=2),
232
+ AvgPool1d(4, 2, padding=2)
233
+ ])
234
+
235
+ def forward(self, y, y_hat):
236
+ y_d_rs = []
237
+ y_d_gs = []
238
+ fmap_rs = []
239
+ fmap_gs = []
240
+ for i, d in enumerate(self.discriminators):
241
+ if i != 0:
242
+ y = self.meanpools[i-1](y)
243
+ y_hat = self.meanpools[i-1](y_hat)
244
+ y_d_r, fmap_r = d(y)
245
+ y_d_g, fmap_g = d(y_hat)
246
+ y_d_rs.append(y_d_r)
247
+ fmap_rs.append(fmap_r)
248
+ y_d_gs.append(y_d_g)
249
+ fmap_gs.append(fmap_g)
250
+
251
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
252
+
253
+
254
+ def feature_loss(fmap_r, fmap_g):
255
+ loss = 0
256
+ for dr, dg in zip(fmap_r, fmap_g):
257
+ for rl, gl in zip(dr, dg):
258
+ loss += torch.mean(torch.abs(rl - gl))
259
+
260
+ return loss*2
261
+
262
+
263
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
264
+ loss = 0
265
+ r_losses = []
266
+ g_losses = []
267
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
268
+ r_loss = torch.mean((1-dr)**2)
269
+ g_loss = torch.mean(dg**2)
270
+ loss += (r_loss + g_loss)
271
+ r_losses.append(r_loss.item())
272
+ g_losses.append(g_loss.item())
273
+
274
+ return loss, r_losses, g_losses
275
+
276
+
277
+ def generator_loss(disc_outputs):
278
+ loss = 0
279
+ gen_losses = []
280
+ for dg in disc_outputs:
281
+ l = torch.mean((1-dg)**2)
282
+ gen_losses.append(l)
283
+ loss += l
284
+
285
+ return loss, gen_losses
286
+
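The discriminators and loss helpers above follow the upstream HiFi-GAN recipe. As a rough sketch of how they fit together during vocoder training (not something this commit wires up; the module path, tensor shapes and segment length are illustrative assumptions):

import torch

from diff_ttsg.hifigan.models import (MultiPeriodDiscriminator,
                                      MultiScaleDiscriminator,
                                      discriminator_loss, feature_loss,
                                      generator_loss)

mpd, msd = MultiPeriodDiscriminator(), MultiScaleDiscriminator()
y = torch.randn(2, 1, 8192)      # real waveform segments
y_hat = torch.randn(2, 1, 8192)  # generator outputs

# Discriminator step: real outputs pushed towards 1, generated towards 0.
y_df_r, y_df_g, _, _ = mpd(y, y_hat.detach())
y_ds_r, y_ds_g, _, _ = msd(y, y_hat.detach())
loss_disc = discriminator_loss(y_df_r, y_df_g)[0] + discriminator_loss(y_ds_r, y_ds_g)[0]

# Generator step: adversarial terms plus feature matching on both discriminators.
_, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_hat)
_, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_hat)
loss_gen = (generator_loss(y_df_hat_g)[0] + generator_loss(y_ds_hat_g)[0]
            + feature_loss(fmap_f_r, fmap_f_g) + feature_loss(fmap_s_r, fmap_s_g))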
diff_ttsg/hifigan/xutils.py ADDED
@@ -0,0 +1,60 @@
1
+ """ from https://github.com/jik876/hifi-gan """
2
+
3
+ import glob
4
+ import os
5
+ import matplotlib
6
+ import torch
7
+ from torch.nn.utils import weight_norm
8
+ matplotlib.use("Agg")
9
+ import matplotlib.pylab as plt
10
+
11
+
12
+ def plot_spectrogram(spectrogram):
13
+ fig, ax = plt.subplots(figsize=(10, 2))
14
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower",
15
+ interpolation='none')
16
+ plt.colorbar(im, ax=ax)
17
+
18
+ fig.canvas.draw()
19
+ plt.close()
20
+
21
+ return fig
22
+
23
+
24
+ def init_weights(m, mean=0.0, std=0.01):
25
+ classname = m.__class__.__name__
26
+ if classname.find("Conv") != -1:
27
+ m.weight.data.normal_(mean, std)
28
+
29
+
30
+ def apply_weight_norm(m):
31
+ classname = m.__class__.__name__
32
+ if classname.find("Conv") != -1:
33
+ weight_norm(m)
34
+
35
+
36
+ def get_padding(kernel_size, dilation=1):
37
+ return int((kernel_size*dilation - dilation)/2)
38
+
39
+
40
+ def load_checkpoint(filepath, device):
41
+ assert os.path.isfile(filepath)
42
+ print("Loading '{}'".format(filepath))
43
+ checkpoint_dict = torch.load(filepath, map_location=device)
44
+ print("Complete.")
45
+ return checkpoint_dict
46
+
47
+
48
+ def save_checkpoint(filepath, obj):
49
+ print("Saving checkpoint to {}".format(filepath))
50
+ torch.save(obj, filepath)
51
+ print("Complete.")
52
+
53
+
54
+ def scan_checkpoint(cp_dir, prefix):
55
+ pattern = os.path.join(cp_dir, prefix + '????????')
56
+ cp_list = glob.glob(pattern)
57
+ if len(cp_list) == 0:
58
+ return None
59
+ return sorted(cp_list)[-1]
60
+
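A usage note (an assumption about intended use, mirroring upstream HiFi-GAN): scan_checkpoint picks the newest file matching a prefix and load_checkpoint reads it; the 'generator' key below follows the layout of the official HiFi-GAN checkpoints.

import torch

from diff_ttsg.hifigan.xutils import load_checkpoint, scan_checkpoint

# Hypothetical directory containing files such as g_00001000, g_00002000, ...
cp_g = scan_checkpoint("checkpoints/hifigan", "g_")
if cp_g is not None:
    state = load_checkpoint(cp_g, torch.device("cpu"))
    generator_state = state["generator"]  # key used by upstream HiFi-GAN checkpoints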
diff_ttsg/models/__init__.py ADDED
File without changes
diff_ttsg/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (159 Bytes). View file
 
diff_ttsg/models/__pycache__/diff_ttsg.cpython-310.pyc ADDED
Binary file (11.2 kB). View file
 
diff_ttsg/models/components/__init__.py ADDED
File without changes
diff_ttsg/models/components/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (170 Bytes). View file
 
diff_ttsg/models/components/__pycache__/diffusion.cpython-310.pyc ADDED
Binary file (12.6 kB). View file
 
diff_ttsg/models/components/__pycache__/text_encoder.cpython-310.pyc ADDED
Binary file (12.3 kB). View file
 
diff_ttsg/models/components/__pycache__/transformer.cpython-310.pyc ADDED
Binary file (6.03 kB). View file
 
diff_ttsg/models/components/diffusion.py ADDED
@@ -0,0 +1,376 @@
1
+ # Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved.
2
+ # This program is free software; you can redistribute it and/or modify
3
+ # it under the terms of the MIT License.
4
+ # This program is distributed in the hope that it will be useful,
5
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
6
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
7
+ # MIT License for more details.
8
+
9
+ import math
10
+
11
+ import torch
12
+ from diffusers import UNet1DModel
13
+ from einops import pack, rearrange
14
+
15
+
16
+ class Mish(torch.nn.Module):
17
+ def forward(self, x):
18
+ return x * torch.tanh(torch.nn.functional.softplus(x))
19
+
20
+
21
+ class Upsample(torch.nn.Module):
22
+ def __init__(self, dim):
23
+ super(Upsample, self).__init__()
24
+ self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)
25
+
26
+ def forward(self, x):
27
+ return self.conv(x)
28
+
29
+
30
+ class Downsample(torch.nn.Module):
31
+ def __init__(self, dim):
32
+ super(Downsample, self).__init__()
33
+ self.conv = torch.nn.Conv2d(dim, dim, 3, 2, 1)
34
+
35
+ def forward(self, x):
36
+ return self.conv(x)
37
+
38
+
39
+ class Rezero(torch.nn.Module):
40
+ def __init__(self, fn):
41
+ super(Rezero, self).__init__()
42
+ self.fn = fn
43
+ self.g = torch.nn.Parameter(torch.zeros(1))
44
+
45
+ def forward(self, x):
46
+ return self.fn(x) * self.g
47
+
48
+
49
+ class Block(torch.nn.Module):
50
+ def __init__(self, dim, dim_out, groups=8):
51
+ super(Block, self).__init__()
52
+ self.block = torch.nn.Sequential(torch.nn.Conv2d(dim, dim_out, 3,
53
+ padding=1), torch.nn.GroupNorm(
54
+ groups, dim_out), Mish())
55
+
56
+ def forward(self, x, mask):
57
+ output = self.block(x * mask)
58
+ return output * mask
59
+
60
+
61
+ class ResnetBlock(torch.nn.Module):
62
+ def __init__(self, dim, dim_out, time_emb_dim, groups=8):
63
+ super(ResnetBlock, self).__init__()
64
+ self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim,
65
+ dim_out))
66
+
67
+ self.block1 = Block(dim, dim_out, groups=groups)
68
+ self.block2 = Block(dim_out, dim_out, groups=groups)
69
+ if dim != dim_out:
70
+ self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
71
+ else:
72
+ self.res_conv = torch.nn.Identity()
73
+
74
+ def forward(self, x, mask, time_emb):
75
+ h = self.block1(x, mask)
76
+ h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
77
+ h = self.block2(h, mask)
78
+ output = h + self.res_conv(x * mask)
79
+ return output
80
+
81
+
82
+ class LinearAttention(torch.nn.Module):
83
+ def __init__(self, dim, heads=4, dim_head=32):
84
+ super(LinearAttention, self).__init__()
85
+ self.heads = heads
86
+ hidden_dim = dim_head * heads
87
+ self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
88
+ self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)
89
+
90
+ def forward(self, x):
91
+ b, c, h, w = x.shape
92
+ qkv = self.to_qkv(x)
93
+ q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)',
94
+ heads = self.heads, qkv=3)
95
+ k = k.softmax(dim=-1)
96
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
97
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
98
+ out = rearrange(out, 'b heads c (h w) -> b (heads c) h w',
99
+ heads=self.heads, h=h, w=w)
100
+ return self.to_out(out)
101
+
102
+
103
+ class Residual(torch.nn.Module):
104
+ def __init__(self, fn):
105
+ super(Residual, self).__init__()
106
+ self.fn = fn
107
+
108
+ def forward(self, x, *args, **kwargs):
109
+ output = self.fn(x, *args, **kwargs) + x
110
+ return output
111
+
112
+
113
+ class UNet1DDiffuser(torch.nn.Module):
114
+ def __init__(self, in_channels=90, out_channels=45, block_out_channels=(256, 512)):
115
+ super(UNet1DDiffuser, self).__init__()
116
+
117
+ self.unet = UNet1DModel(
118
+ in_channels=in_channels,
119
+ out_channels=out_channels,
120
+ down_block_types = ("DownBlock1DNoSkip", "AttnDownBlock1D"),
121
+ up_block_types = ("AttnUpBlock1D", "UpBlock1DNoSkip"),
122
+ mid_block_type = "UNetMidBlock1D",
123
+ block_out_channels=block_out_channels,
124
+ use_timestep_embedding=True,
125
+ )
126
+
127
+
128
+ def forward(self, x, mask, mu, t, spk=None):
129
+ x = pack([x, mu], "b * t")[0]
130
+
131
+ return self.unet(x, t).sample * mask
132
+
133
+ class SinusoidalPosEmb(torch.nn.Module):
134
+ def __init__(self, dim):
135
+ super(SinusoidalPosEmb, self).__init__()
136
+ self.dim = dim
137
+
138
+ def forward(self, x, scale=1000):
139
+ device = x.device
140
+ half_dim = self.dim // 2
141
+ emb = math.log(10000) / (half_dim - 1)
142
+ emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
143
+ emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
144
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
145
+ return emb
146
+
147
+
148
+ class GradLogPEstimator2d(torch.nn.Module):
149
+ def __init__(self, dim, dim_mults=(1, 2, 4), groups=8,
150
+ n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000):
151
+ super(GradLogPEstimator2d, self).__init__()
152
+ self.dim = dim
153
+ self.dim_mults = dim_mults
154
+ self.groups = groups
155
+ self.n_spks = n_spks if not isinstance(n_spks, type(None)) else 1
156
+ self.spk_emb_dim = spk_emb_dim
157
+ self.pe_scale = pe_scale
158
+
159
+ if n_spks > 1:
160
+ self.spk_mlp = torch.nn.Sequential(torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(),
161
+ torch.nn.Linear(spk_emb_dim * 4, n_feats))
162
+ self.time_pos_emb = SinusoidalPosEmb(dim)
163
+ self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(),
164
+ torch.nn.Linear(dim * 4, dim))
165
+
166
+ dims = [2 + (1 if n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)]
167
+ in_out = list(zip(dims[:-1], dims[1:]))
168
+ self.downs = torch.nn.ModuleList([])
169
+ self.ups = torch.nn.ModuleList([])
170
+ num_resolutions = len(in_out)
171
+
172
+ for ind, (dim_in, dim_out) in enumerate(in_out):
173
+ is_last = ind >= (num_resolutions - 1)
174
+ self.downs.append(torch.nn.ModuleList([
175
+ ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
176
+ ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
177
+ Residual(Rezero(LinearAttention(dim_out))),
178
+ Downsample(dim_out) if not is_last else torch.nn.Identity()]))
179
+
180
+ mid_dim = dims[-1]
181
+ self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
182
+ self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
183
+ self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
184
+
185
+ for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
186
+ self.ups.append(torch.nn.ModuleList([
187
+ ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
188
+ ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
189
+ Residual(Rezero(LinearAttention(dim_in))),
190
+ Upsample(dim_in)]))
191
+ self.final_block = Block(dim, dim)
192
+ self.final_conv = torch.nn.Conv2d(dim, 1, 1)
193
+
194
+ def forward(self, x, mask, mu, t, spk=None):
195
+ if not isinstance(spk, type(None)):
196
+ s = self.spk_mlp(spk)
197
+
198
+ t = self.time_pos_emb(t, scale=self.pe_scale)
199
+ t = self.mlp(t)
200
+
201
+ if self.n_spks < 2:
202
+ x = torch.stack([mu, x], 1)
203
+ else:
204
+ s = s.unsqueeze(-1).repeat(1, 1, x.shape[-1])
205
+ x = torch.stack([mu, x, s], 1)
206
+ mask = mask.unsqueeze(1)
207
+
208
+ hiddens = []
209
+ masks = [mask]
210
+ for resnet1, resnet2, attn, downsample in self.downs:
211
+ mask_down = masks[-1]
212
+ x = resnet1(x, mask_down, t)
213
+ x = resnet2(x, mask_down, t)
214
+ x = attn(x)
215
+ hiddens.append(x)
216
+ x = downsample(x * mask_down)
217
+ masks.append(mask_down[:, :, :, ::2])
218
+ masks = masks[:-1]
219
+ mask_mid = masks[-1]
220
+ x = self.mid_block1(x, mask_mid, t)
221
+ x = self.mid_attn(x)
222
+ x = self.mid_block2(x, mask_mid, t)
223
+
224
+ for resnet1, resnet2, attn, upsample in self.ups:
225
+ mask_up = masks.pop()
226
+ x = torch.cat((x, hiddens.pop()), dim=1)
227
+ x = resnet1(x, mask_up, t)
228
+ x = resnet2(x, mask_up, t)
229
+ x = attn(x)
230
+ x = upsample(x * mask_up)
231
+
232
+ x = self.final_block(x, mask)
233
+ output = self.final_conv(x * mask)
234
+
235
+ return (output * mask).squeeze(1)
236
+
237
+
238
+ def get_noise(t, beta_init, beta_term, cumulative=False):
239
+ if cumulative:
240
+ noise = beta_init*t + 0.5*(beta_term - beta_init)*(t**2)
241
+ else:
242
+ noise = beta_init + (beta_term - beta_init)*t
243
+ return noise
244
+
245
+
246
+ class Diffusion(torch.nn.Module):
247
+ def __init__(self, n_feats, dim,
248
+ n_spks=1, spk_emb_dim=64,
249
+ beta_min=0.05, beta_max=20, pe_scale=1000):
250
+ super(Diffusion, self).__init__()
251
+ self.n_feats = n_feats
252
+ self.dim = dim
253
+ self.n_spks = n_spks
254
+ self.spk_emb_dim = spk_emb_dim
255
+ self.beta_min = beta_min
256
+ self.beta_max = beta_max
257
+ self.pe_scale = pe_scale
258
+
259
+ self.estimator = GradLogPEstimator2d(dim, n_spks=n_spks,
260
+ spk_emb_dim=spk_emb_dim,
261
+ pe_scale=pe_scale)
262
+
263
+ def forward_diffusion(self, x0, mask, mu, t):
264
+ time = t.unsqueeze(-1).unsqueeze(-1)
265
+ cum_noise = get_noise(time, self.beta_min, self.beta_max, cumulative=True)
266
+ mean = x0*torch.exp(-0.5*cum_noise) + mu*(1.0 - torch.exp(-0.5*cum_noise))
267
+ variance = 1.0 - torch.exp(-cum_noise)
268
+ z = torch.randn(x0.shape, dtype=x0.dtype, device=x0.device,
269
+ requires_grad=False)
270
+ xt = mean + z * torch.sqrt(variance)
271
+ return xt * mask, z * mask
272
+
273
+ @torch.no_grad()
274
+ def reverse_diffusion(self, z, mask, mu, n_timesteps, stoc=False, spk=None):
275
+ h = 1.0 / n_timesteps
276
+ xt = z * mask
277
+ for i in range(n_timesteps):
278
+ t = (1.0 - (i + 0.5)*h) * torch.ones(z.shape[0], dtype=z.dtype,
279
+ device=z.device)
280
+ time = t.unsqueeze(-1).unsqueeze(-1)
281
+ noise_t = get_noise(time, self.beta_min, self.beta_max,
282
+ cumulative=False)
283
+ if stoc: # adds stochastic term
284
+ dxt_det = 0.5 * (mu - xt) - self.estimator(xt, mask, mu, t, spk)
285
+ dxt_det = dxt_det * noise_t * h
286
+ dxt_stoc = torch.randn(z.shape, dtype=z.dtype, device=z.device,
287
+ requires_grad=False)
288
+ dxt_stoc = dxt_stoc * torch.sqrt(noise_t * h)
289
+ dxt = dxt_det + dxt_stoc
290
+ else:
291
+ dxt = 0.5 * (mu - xt - self.estimator(xt, mask, mu, t, spk))
292
+ dxt = dxt * noise_t * h
293
+ xt = (xt - dxt) * mask
294
+ return xt
295
+
296
+ @torch.no_grad()
297
+ def forward(self, z, mask, mu, n_timesteps, stoc=False, spk=None):
298
+ return self.reverse_diffusion(z, mask, mu, n_timesteps, stoc, spk)
299
+
300
+ def loss_t(self, x0, mask, mu, t, spk=None):
301
+ xt, z = self.forward_diffusion(x0, mask, mu, t)
302
+ time = t.unsqueeze(-1).unsqueeze(-1) # t =[0.6215, 0.0191, 0.0391]
303
+ cum_noise = get_noise(time, self.beta_min, self.beta_max, cumulative=True)
304
+ noise_estimation = self.estimator(xt, mask, mu, t, spk) # xt = [3, 80, 172], mask=[3, 1, 172], mu=[3, 80, 172], t=[3]
305
+ noise_estimation *= torch.sqrt(1.0 - torch.exp(-cum_noise))
306
+ loss = torch.sum((noise_estimation + z)**2) / (torch.sum(mask)*self.n_feats)
307
+ return loss, xt
308
+
309
+ def compute_loss(self, x0, mask, mu, spk=None, offset=1e-5):
310
+ t = torch.rand(x0.shape[0], dtype=x0.dtype, device=x0.device,
311
+ requires_grad=False)
312
+ t = torch.clamp(t, offset, 1.0 - offset)
313
+ return self.loss_t(x0, mask, mu, t, spk)
314
+
315
+
316
+
317
+ class Diffusion_Motion(torch.nn.Module):
318
+ def __init__(self, in_channels, motion_decoder_channels=(256, 256), beta_min=0.05, beta_max=20):
319
+ super(Diffusion_Motion, self).__init__()
320
+ self.in_channels = in_channels
321
+ self.beta_min = beta_min
322
+ self.beta_max = beta_max
323
+
324
+ self.estimator = UNet1DDiffuser(block_out_channels=motion_decoder_channels)
325
+
326
+ def forward_diffusion(self, x0, mask, mu, t):
327
+ time = t.unsqueeze(-1).unsqueeze(-1)
328
+ cum_noise = get_noise(time, self.beta_min, self.beta_max, cumulative=True)
329
+ mean = x0*torch.exp(-0.5*cum_noise) + mu*(1.0 - torch.exp(-0.5*cum_noise))
330
+ variance = 1.0 - torch.exp(-cum_noise)
331
+ z = torch.randn(x0.shape, dtype=x0.dtype, device=x0.device,
332
+ requires_grad=False)
333
+ xt = mean + z * torch.sqrt(variance)
334
+ return xt * mask, z * mask
335
+
336
+ @torch.no_grad()
337
+ def reverse_diffusion(self, z, mask, mu, n_timesteps, stoc=False, spk=None):
338
+ h = 1.0 / n_timesteps
339
+ xt = z * mask
340
+ for i in range(n_timesteps):
341
+ t = (1.0 - (i + 0.5)*h) * torch.ones(z.shape[0], dtype=z.dtype,
342
+ device=z.device)
343
+ time = t.unsqueeze(-1).unsqueeze(-1)
344
+ noise_t = get_noise(time, self.beta_min, self.beta_max,
345
+ cumulative=False)
346
+ if stoc: # adds stochastic term
347
+ dxt_det = 0.5 * (mu - xt) - self.estimator(xt, mask, mu, t, spk)
348
+ dxt_det = dxt_det * noise_t * h
349
+ dxt_stoc = torch.randn(z.shape, dtype=z.dtype, device=z.device,
350
+ requires_grad=False)
351
+ dxt_stoc = dxt_stoc * torch.sqrt(noise_t * h)
352
+ dxt = dxt_det + dxt_stoc
353
+ else:
354
+ dxt = 0.5 * (mu - xt - self.estimator(xt, mask, mu, t, spk))
355
+ dxt = dxt * noise_t * h
356
+ xt = (xt - dxt) * mask
357
+ return xt
358
+
359
+ @torch.no_grad()
360
+ def forward(self, z, mask, mu, n_timesteps, stoc=False, spk=None):
361
+ return self.reverse_diffusion(z, mask, mu, n_timesteps, stoc, spk)
362
+
363
+ def loss_t(self, x0, mask, mu, t, spk=None):
364
+ xt, z = self.forward_diffusion(x0, mask, mu, t)
365
+ time = t.unsqueeze(-1).unsqueeze(-1) # t =[0.6215, 0.0191, 0.0391]
366
+ cum_noise = get_noise(time, self.beta_min, self.beta_max, cumulative=True)
367
+ noise_estimation = self.estimator(xt, mask, mu, t, spk) # xt = [3, 80, 172], mask=[3, 1, 172], mu=[3, 80, 172], t=[3]
368
+ noise_estimation *= torch.sqrt(1.0 - torch.exp(-cum_noise))
369
+ loss = torch.sum((noise_estimation + z)**2) / (torch.sum(mask)*self.in_channels)
370
+ return loss, xt
371
+
372
+ def compute_loss(self, x0, mask, mu, spk=None, offset=1e-5):
373
+ t = torch.rand(x0.shape[0], dtype=x0.dtype, device=x0.device,
374
+ requires_grad=False)
375
+ t = torch.clamp(t, offset, 1.0 - offset)
376
+ return self.loss_t(x0, mask, mu, t, spk)
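Both diffusion modules share the linear noise schedule implemented by get_noise: beta(t) = beta_min + (beta_max - beta_min) * t, and with cumulative=True its integral beta_min * t + 0.5 * (beta_max - beta_min) * t**2, which drives the forward process towards N(mu, I). A minimal shape-check sketch for the mel decoder follows; the sizes are made up, and the time axis just has to be divisible by 4 so the 2D U-Net can down- and upsample cleanly.

import torch

from diff_ttsg.models.components.diffusion import Diffusion

decoder = Diffusion(n_feats=80, dim=64, n_spks=1)
x0 = torch.randn(2, 80, 172)    # normalised target mel-spectrograms
mu = torch.randn(2, 80, 172)    # aligned encoder outputs (prior mean)
mask = torch.ones(2, 1, 172)

loss, xt = decoder.compute_loss(x0, mask, mu)  # training objective (noise prediction)
z = mu + torch.randn_like(mu)                  # sample from the terminal distribution
mel = decoder(z, mask, mu, n_timesteps=50)     # reverse diffusion at inference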
diff_ttsg/models/components/text_encoder.py ADDED
@@ -0,0 +1,384 @@
1
+ """ from https://github.com/jaywalnut310/glow-tts """
2
+
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from conformer import ConformerBlock
8
+ from einops import rearrange
9
+
10
+ from diff_ttsg.models.components.transformer import FFTransformer
11
+ from diff_ttsg.utils.model import convert_pad_shape, sequence_mask
12
+
13
+
14
+ class LayerNorm(nn.Module):
15
+ def __init__(self, channels, eps=1e-4):
16
+ super(LayerNorm, self).__init__()
17
+ self.channels = channels
18
+ self.eps = eps
19
+
20
+ self.gamma = torch.nn.Parameter(torch.ones(channels))
21
+ self.beta = torch.nn.Parameter(torch.zeros(channels))
22
+
23
+ def forward(self, x):
24
+ n_dims = len(x.shape)
25
+ mean = torch.mean(x, 1, keepdim=True)
26
+ variance = torch.mean((x - mean)**2, 1, keepdim=True)
27
+
28
+ x = (x - mean) * torch.rsqrt(variance + self.eps)
29
+
30
+ shape = [1, -1] + [1] * (n_dims - 2)
31
+ x = x * self.gamma.view(*shape) + self.beta.view(*shape)
32
+ return x
33
+
34
+
35
+ class ConvReluNorm(nn.Module):
36
+ def __init__(self, in_channels, hidden_channels, out_channels, kernel_size,
37
+ n_layers, p_dropout):
38
+ super(ConvReluNorm, self).__init__()
39
+ self.in_channels = in_channels
40
+ self.hidden_channels = hidden_channels
41
+ self.out_channels = out_channels
42
+ self.kernel_size = kernel_size
43
+ self.n_layers = n_layers
44
+ self.p_dropout = p_dropout
45
+
46
+ self.conv_layers = torch.nn.ModuleList()
47
+ self.norm_layers = torch.nn.ModuleList()
48
+ self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels,
49
+ kernel_size, padding=kernel_size//2))
50
+ self.norm_layers.append(LayerNorm(hidden_channels))
51
+ self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
52
+ for _ in range(n_layers - 1):
53
+ self.conv_layers.append(torch.nn.Conv1d(hidden_channels, hidden_channels,
54
+ kernel_size, padding=kernel_size//2))
55
+ self.norm_layers.append(LayerNorm(hidden_channels))
56
+ self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
57
+ self.proj.weight.data.zero_()
58
+ self.proj.bias.data.zero_()
59
+
60
+ def forward(self, x, x_mask):
61
+ x_org = x
62
+ for i in range(self.n_layers):
63
+ x = self.conv_layers[i](x * x_mask)
64
+ x = self.norm_layers[i](x)
65
+ x = self.relu_drop(x)
66
+ x = x_org + self.proj(x)
67
+ return x * x_mask
68
+
69
+
70
+ class DurationPredictor(nn.Module):
71
+ def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
72
+ super(DurationPredictor, self).__init__()
73
+ self.in_channels = in_channels
74
+ self.filter_channels = filter_channels
75
+ self.p_dropout = p_dropout
76
+
77
+ self.drop = torch.nn.Dropout(p_dropout)
78
+ self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels,
79
+ kernel_size, padding=kernel_size//2)
80
+ self.norm_1 = LayerNorm(filter_channels)
81
+ self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels,
82
+ kernel_size, padding=kernel_size//2)
83
+ self.norm_2 = LayerNorm(filter_channels)
84
+ self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
85
+
86
+ def forward(self, x, x_mask):
87
+ x = self.conv_1(x * x_mask)
88
+ x = torch.relu(x)
89
+ x = self.norm_1(x)
90
+ x = self.drop(x)
91
+ x = self.conv_2(x * x_mask)
92
+ x = torch.relu(x)
93
+ x = self.norm_2(x)
94
+ x = self.drop(x)
95
+ x = self.proj(x * x_mask)
96
+ return x * x_mask
97
+
98
+
99
+ class MultiHeadAttention(nn.Module):
100
+ def __init__(self, channels, out_channels, n_heads, window_size=None,
101
+ heads_share=True, p_dropout=0.0, proximal_bias=False,
102
+ proximal_init=False):
103
+ super(MultiHeadAttention, self).__init__()
104
+ assert channels % n_heads == 0
105
+
106
+ self.channels = channels
107
+ self.out_channels = out_channels
108
+ self.n_heads = n_heads
109
+ self.window_size = window_size
110
+ self.heads_share = heads_share
111
+ self.proximal_bias = proximal_bias
112
+ self.p_dropout = p_dropout
113
+ self.attn = None
114
+
115
+ self.k_channels = channels // n_heads
116
+ self.conv_q = torch.nn.Conv1d(channels, channels, 1)
117
+ self.conv_k = torch.nn.Conv1d(channels, channels, 1)
118
+ self.conv_v = torch.nn.Conv1d(channels, channels, 1)
119
+ if window_size is not None:
120
+ n_heads_rel = 1 if heads_share else n_heads
121
+ rel_stddev = self.k_channels**-0.5
122
+ self.emb_rel_k = torch.nn.Parameter(torch.randn(n_heads_rel,
123
+ window_size * 2 + 1, self.k_channels) * rel_stddev)
124
+ self.emb_rel_v = torch.nn.Parameter(torch.randn(n_heads_rel,
125
+ window_size * 2 + 1, self.k_channels) * rel_stddev)
126
+ self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
127
+ self.drop = torch.nn.Dropout(p_dropout)
128
+
129
+ torch.nn.init.xavier_uniform_(self.conv_q.weight)
130
+ torch.nn.init.xavier_uniform_(self.conv_k.weight)
131
+ if proximal_init:
132
+ self.conv_k.weight.data.copy_(self.conv_q.weight.data)
133
+ self.conv_k.bias.data.copy_(self.conv_q.bias.data)
134
+ torch.nn.init.xavier_uniform_(self.conv_v.weight)
135
+
136
+ def forward(self, x, c, attn_mask=None):
137
+ q = self.conv_q(x)
138
+ k = self.conv_k(c)
139
+ v = self.conv_v(c)
140
+
141
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
142
+
143
+ x = self.conv_o(x)
144
+ return x
145
+
146
+ def attention(self, query, key, value, mask=None):
147
+ b, d, t_s, t_t = (*key.size(), query.size(2))
148
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
149
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
150
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
151
+
152
+ scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
153
+ if self.window_size is not None:
154
+ assert t_s == t_t, "Relative attention is only available for self-attention."
155
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
156
+ rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
157
+ rel_logits = self._relative_position_to_absolute_position(rel_logits)
158
+ scores_local = rel_logits / math.sqrt(self.k_channels)
159
+ scores = scores + scores_local
160
+ if self.proximal_bias:
161
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
162
+ scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device,
163
+ dtype=scores.dtype)
164
+ if mask is not None:
165
+ scores = scores.masked_fill(mask == 0, -1e4)
166
+ p_attn = torch.nn.functional.softmax(scores, dim=-1)
167
+ p_attn = self.drop(p_attn)
168
+ output = torch.matmul(p_attn, value)
169
+ if self.window_size is not None:
170
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
171
+ value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
172
+ output = output + self._matmul_with_relative_values(relative_weights,
173
+ value_relative_embeddings)
174
+ output = output.transpose(2, 3).contiguous().view(b, d, t_t)
175
+ return output, p_attn
176
+
177
+ def _matmul_with_relative_values(self, x, y):
178
+ ret = torch.matmul(x, y.unsqueeze(0))
179
+ return ret
180
+
181
+ def _matmul_with_relative_keys(self, x, y):
182
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
183
+ return ret
184
+
185
+ def _get_relative_embeddings(self, relative_embeddings, length):
186
+ pad_length = max(length - (self.window_size + 1), 0)
187
+ slice_start_position = max((self.window_size + 1) - length, 0)
188
+ slice_end_position = slice_start_position + 2 * length - 1
189
+ if pad_length > 0:
190
+ padded_relative_embeddings = torch.nn.functional.pad(
191
+ relative_embeddings, convert_pad_shape([[0, 0],
192
+ [pad_length, pad_length], [0, 0]]))
193
+ else:
194
+ padded_relative_embeddings = relative_embeddings
195
+ used_relative_embeddings = padded_relative_embeddings[:,
196
+ slice_start_position:slice_end_position]
197
+ return used_relative_embeddings
198
+
199
+ def _relative_position_to_absolute_position(self, x):
200
+ batch, heads, length, _ = x.size()
201
+ x = torch.nn.functional.pad(x, convert_pad_shape([[0,0],[0,0],[0,0],[0,1]]))
202
+ x_flat = x.view([batch, heads, length * 2 * length])
203
+ x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0,0],[0,0],[0,length-1]]))
204
+ x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:]
205
+ return x_final
206
+
207
+ def _absolute_position_to_relative_position(self, x):
208
+ batch, heads, length, _ = x.size()
209
+ x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]]))
210
+ x_flat = x.view([batch, heads, length**2 + length*(length - 1)])
211
+ x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
212
+ x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:]
213
+ return x_final
214
+
215
+ def _attention_bias_proximal(self, length):
216
+ r = torch.arange(length, dtype=torch.float32)
217
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
218
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
219
+
220
+
221
+ class FFN(nn.Module):
222
+ def __init__(self, in_channels, out_channels, filter_channels, kernel_size,
223
+ p_dropout=0.0):
224
+ super(FFN, self).__init__()
225
+ self.in_channels = in_channels
226
+ self.out_channels = out_channels
227
+ self.filter_channels = filter_channels
228
+ self.kernel_size = kernel_size
229
+ self.p_dropout = p_dropout
230
+
231
+ self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size,
232
+ padding=kernel_size//2)
233
+ self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size,
234
+ padding=kernel_size//2)
235
+ self.drop = torch.nn.Dropout(p_dropout)
236
+
237
+ def forward(self, x, x_mask):
238
+ x = self.conv_1(x * x_mask)
239
+ x = torch.relu(x)
240
+ x = self.drop(x)
241
+ x = self.conv_2(x * x_mask)
242
+ return x * x_mask
243
+
244
+
245
+ class Encoder(nn.Module):
246
+ def __init__(self, hidden_channels, filter_channels, n_heads, n_layers,
247
+ kernel_size=1, p_dropout=0.0, window_size=None, **kwargs):
248
+ super(Encoder, self).__init__()
249
+ self.hidden_channels = hidden_channels
250
+ self.filter_channels = filter_channels
251
+ self.n_heads = n_heads
252
+ self.n_layers = n_layers
253
+ self.kernel_size = kernel_size
254
+ self.p_dropout = p_dropout
255
+ self.window_size = window_size
256
+
257
+ self.drop = torch.nn.Dropout(p_dropout)
258
+ self.attn_layers = torch.nn.ModuleList()
259
+ self.norm_layers_1 = torch.nn.ModuleList()
260
+ self.ffn_layers = torch.nn.ModuleList()
261
+ self.norm_layers_2 = torch.nn.ModuleList()
262
+ for _ in range(self.n_layers):
263
+ self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels,
264
+ n_heads, window_size=window_size, p_dropout=p_dropout))
265
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
266
+ self.ffn_layers.append(FFN(hidden_channels, hidden_channels,
267
+ filter_channels, kernel_size, p_dropout=p_dropout))
268
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
269
+
270
+ def forward(self, x, x_mask):
271
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
272
+ for i in range(self.n_layers):
273
+ x = x * x_mask
274
+ y = self.attn_layers[i](x, x, attn_mask)
275
+ y = self.drop(y)
276
+ x = self.norm_layers_1[i](x + y)
277
+ y = self.ffn_layers[i](x, x_mask)
278
+ y = self.drop(y)
279
+ x = self.norm_layers_2[i](x + y)
280
+ x = x * x_mask
281
+ return x
282
+
283
+
284
+ class TextEncoder(nn.Module):
285
+ def __init__(self, n_vocab, n_feats, n_channels, filter_channels,
286
+ filter_channels_dp, n_heads, n_layers, kernel_size,
287
+ p_dropout, window_size=None, spk_emb_dim=64, n_spks=1, encoder_type=None):
288
+ super(TextEncoder, self).__init__()
289
+ self.n_vocab = n_vocab
290
+ self.n_feats = n_feats
291
+ self.n_channels = n_channels
292
+ self.filter_channels = filter_channels
293
+ self.filter_channels_dp = filter_channels_dp
294
+ self.n_heads = n_heads
295
+ self.n_layers = n_layers
296
+ self.kernel_size = kernel_size
297
+ self.p_dropout = p_dropout
298
+ self.window_size = window_size
299
+ self.spk_emb_dim = spk_emb_dim
300
+ self.n_spks = n_spks
301
+
302
+ self.emb = torch.nn.Embedding(n_vocab, n_channels)
303
+ torch.nn.init.normal_(self.emb.weight, 0.0, n_channels**-0.5)
304
+
305
+ self.prenet = ConvReluNorm(n_channels, n_channels, n_channels,
306
+ kernel_size=5, n_layers=3, p_dropout=0.5)
307
+ if encoder_type == "default":
308
+ self.encoder = Encoder(n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels, n_heads, n_layers,
309
+ kernel_size, p_dropout, window_size=window_size)
310
+ elif encoder_type == "myencoder":
311
+ self.encoder = FFTransformer(
312
+ n_layers, n_heads, n_channels + (spk_emb_dim if n_spks > 1 else 0), 64, 1024, kernel_size,
313
+ p_dropout, p_dropout, rel_attention=False, rel_window_size=window_size
314
+ )
315
+ else:
316
+ raise ValueError(f"Unknown encoder type: {encoder_type}")
317
+
318
+
319
+ self.proj_m = torch.nn.Conv1d(n_channels + (spk_emb_dim if n_spks > 1 else 0), n_feats, 1)
320
+ self.proj_w = DurationPredictor(n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels_dp,
321
+ kernel_size, p_dropout)
322
+
323
+ def forward(self, x, x_lengths, spk=None):
324
+ x = self.emb(x) * math.sqrt(self.n_channels)
325
+ x = torch.transpose(x, 1, -1)
326
+ x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
327
+
328
+ x = self.prenet(x, x_mask)
329
+ if self.n_spks > 1:
330
+ x = torch.cat([x, spk.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
331
+ x = self.encoder(x, x_mask)
332
+ mu = self.proj_m(x) * x_mask
333
+
334
+ x_dp = torch.detach(x)
335
+ logw = self.proj_w(x_dp, x_mask)
336
+
337
+ return mu, logw, x_mask
338
+
339
+ class MuMotionEncoder(nn.Module):
340
+ def __init__(
341
+ self,
342
+ input_channels,
343
+ output_channels,
344
+ hidden_channels,
345
+ d_head,
346
+ n_layer,
347
+ n_head,
348
+ ff_mult,
349
+ conv_expansion_factor,
350
+ dropout,
351
+ dropatt,
352
+ dropconv,
353
+ conv_kernel_size,
354
+ ) -> None:
355
+ super().__init__()
356
+
357
+ self.in_projection = nn.Conv1d(input_channels, hidden_channels, 1)
358
+ self.layers = nn.ModuleList()
359
+ for _ in range(n_layer):
360
+ self.layers.append(
361
+ ConformerBlock(
362
+ dim=hidden_channels,
363
+ dim_head=d_head,
364
+ heads=n_head,
365
+ ff_mult=ff_mult,
366
+ conv_expansion_factor=conv_expansion_factor,
367
+ ff_dropout=dropout,
368
+ attn_dropout=dropatt,
369
+ conv_dropout=dropconv,
370
+ conv_kernel_size=conv_kernel_size,
371
+ )
372
+ )
373
+
374
+ self.motion_projection = nn.Conv1d(hidden_channels, output_channels, 1)
375
+
376
+ def forward(self, x, mask):
377
+ x = self.in_projection(x)
378
+ x = rearrange(x, "b c t -> b t c")
379
+ mask = rearrange(mask, "b 1 t -> b (1 t)").bool()
380
+ for layer in self.layers:
381
+ x = layer(x, mask)
382
+ x = rearrange(x, "b t c -> b c t")
383
+ x = self.motion_projection(x)
384
+ return x
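For orientation, a minimal sketch of running the text encoder on its own; the hyper-parameter values below are illustrative assumptions, not the project's configuration.

import torch

from diff_ttsg.models.components.text_encoder import TextEncoder

encoder = TextEncoder(
    n_vocab=149, n_feats=80, n_channels=192, filter_channels=768,
    filter_channels_dp=256, n_heads=2, n_layers=6, kernel_size=3,
    p_dropout=0.1, window_size=4, encoder_type="default",
)
x = torch.randint(1, 149, (2, 37))          # batches of phoneme ids
x_lengths = torch.tensor([37, 25])
mu_x, logw, x_mask = encoder(x, x_lengths)  # prior means, log-durations, text mask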
diff_ttsg/models/components/transformer.py ADDED
@@ -0,0 +1,250 @@
1
+ # Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from conformer.conformer import Attention as RelAttention
19
+ from einops import rearrange
20
+
21
+
22
+ class PositionalEmbedding(nn.Module):
23
+ def __init__(self, demb):
24
+ super().__init__()
25
+ self.demb = demb
26
+ inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
27
+ self.register_buffer("inv_freq", inv_freq)
28
+
29
+ def forward(self, pos_seq, bsz=None):
30
+ sinusoid_inp = torch.matmul(torch.unsqueeze(pos_seq, -1), torch.unsqueeze(self.inv_freq, 0))
31
+ pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=1)
32
+ if bsz is not None:
33
+ return pos_emb[None, :, :].expand(bsz, -1, -1)
34
+ else:
35
+ return pos_emb[None, :, :]
36
+
37
+
38
+ class PositionwiseConvFF(nn.Module):
39
+ def __init__(self, d_model, d_inner, kernel_size, dropout, pre_lnorm=False):
40
+ super().__init__()
41
+
42
+ self.d_model = d_model
43
+ self.d_inner = d_inner
44
+ self.dropout = dropout
45
+
46
+ self.CoreNet = nn.Sequential(
47
+ nn.Conv1d(d_model, d_inner, kernel_size, 1, (kernel_size // 2)),
48
+ nn.ReLU(),
49
+ # nn.Dropout(dropout), # worse convergence
50
+ nn.Conv1d(d_inner, d_model, kernel_size, 1, (kernel_size // 2)),
51
+ nn.Dropout(dropout),
52
+ )
53
+ self.layer_norm = nn.LayerNorm(d_model)
54
+ self.pre_lnorm = pre_lnorm
55
+
56
+ def forward(self, inp):
57
+ return self._forward(inp)
58
+
59
+ def _forward(self, inp):
60
+ if self.pre_lnorm:
61
+ # layer normalization + positionwise feed-forward
62
+ # core_out = inp
63
+ core_out = self.CoreNet(self.layer_norm(inp).transpose(1, 2))
64
+ core_out = core_out.transpose(1, 2)
65
+
66
+ # residual connection
67
+ output = core_out + inp
68
+ else:
69
+ # positionwise feed-forward
70
+ core_out = inp.transpose(1, 2)
71
+ core_out = self.CoreNet(core_out)
72
+ core_out = core_out.transpose(1, 2)
73
+
74
+ # residual connection + layer normalization
75
+ output = self.layer_norm(inp + core_out).to(inp.dtype)
76
+
77
+ return output
78
+
79
+
80
+ class MultiHeadAttn(nn.Module):
81
+ def __init__(
82
+ self, n_head, d_model, d_head, dropout, rel_attention, dropatt=0.1, pre_lnorm=True, rel_window_size=10
83
+ ):
84
+ super().__init__()
85
+
86
+ self.n_head = n_head
87
+ self.d_model = d_model
88
+ self.d_head = d_head
89
+ self.scale = 1 / (d_head**0.5)
90
+ self.pre_lnorm = pre_lnorm
91
+ self.rel_attention = rel_attention
92
+ if rel_attention:
93
+ self.attn = RelAttention(d_model, n_head, d_head, dropout, max_pos_emb=rel_window_size)
94
+ else:
95
+ self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head)
96
+ self.drop = nn.Dropout(dropout)
97
+ self.dropatt = nn.Dropout(dropatt)
98
+ self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
99
+
100
+ self.layer_norm = nn.LayerNorm(d_model)
101
+
102
+ def forward(self, inp, attn_mask=None):
103
+ return self._forward(inp, attn_mask)
104
+
105
+ def _forward(self, inp, attn_mask=None):
106
+ residual = inp
107
+
108
+ if self.pre_lnorm:
109
+ # layer normalization
110
+ inp = self.layer_norm(inp)
111
+
112
+ if not self.rel_attention:
113
+ n_head, d_head = self.n_head, self.d_head
114
+
115
+ head_q, head_k, head_v = torch.chunk(self.qkv_net(inp), 3, dim=2)
116
+ head_q = head_q.view(inp.size(0), inp.size(1), n_head, d_head)
117
+ head_k = head_k.view(inp.size(0), inp.size(1), n_head, d_head)
118
+ head_v = head_v.view(inp.size(0), inp.size(1), n_head, d_head)
119
+
120
+ q = head_q.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
121
+ k = head_k.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
122
+ v = head_v.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
123
+
124
+ attn_score = torch.bmm(q, k.transpose(1, 2))
125
+ attn_score.mul_(self.scale)
126
+
127
+ if attn_mask is not None:
128
+ attn_mask = attn_mask.unsqueeze(1).to(attn_score.dtype)
129
+ attn_mask = attn_mask.repeat(n_head, attn_mask.size(2), 1)
130
+ attn_score.masked_fill_(attn_mask.to(torch.bool), -float("inf"))
131
+
132
+ attn_prob = F.softmax(attn_score, dim=2)
133
+ attn_prob = self.dropatt(attn_prob)
134
+ attn_vec = torch.bmm(attn_prob, v)
135
+
136
+ attn_vec = attn_vec.view(n_head, inp.size(0), inp.size(1), d_head)
137
+ attn_vec = attn_vec.permute(1, 2, 0, 3).contiguous().view(inp.size(0), inp.size(1), n_head * d_head)
138
+
139
+ # linear projection
140
+ attn_out = self.o_net(attn_vec)
141
+ attn_out = self.drop(attn_out)
142
+ else:
143
+ attn_out = self.attn(inp, mask=attn_mask)
144
+
145
+ if self.pre_lnorm:
146
+ # residual connection
147
+ output = residual + attn_out
148
+ else:
149
+ # residual connection + layer normalization
150
+ output = self.layer_norm(residual + attn_out)
151
+
152
+ output = output.to(attn_out.dtype)
153
+
154
+ return output
155
+
156
+
157
+ class TransformerLayer(nn.Module):
158
+ def __init__(self, n_head, d_model, d_head, d_inner, kernel_size, dropout, **kwargs):
159
+ super().__init__()
160
+
161
+ self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
162
+ self.pos_ff = PositionwiseConvFF(d_model, d_inner, kernel_size, dropout, pre_lnorm=kwargs.get("pre_lnorm"))
163
+
164
+ def forward(self, dec_inp, mask=None):
165
+ output = self.dec_attn(dec_inp, attn_mask=~mask.squeeze(2))
166
+ output *= mask
167
+ output = self.pos_ff(output)
168
+ output *= mask
169
+ return output
170
+
171
+
172
+ class FFTransformer(nn.Module):
173
+ def __init__(
174
+ self,
175
+ n_layer,
176
+ n_head,
177
+ hidden_channels,
178
+ d_head,
179
+ d_inner,
180
+ kernel_size,
181
+ dropout,
182
+ dropatt,
183
+ dropemb=0.0,
184
+ embed_input=False,
185
+ n_embed=None,
186
+ d_embed=None,
187
+ padding_idx=0,
188
+ pre_lnorm=True,
189
+ rel_attention=True,
190
+ rel_window_size=10,
191
+ ):
192
+ super().__init__()
193
+ self.d_model = hidden_channels
194
+ self.n_head = n_head
195
+ self.d_head = d_head
196
+ self.padding_idx = padding_idx
197
+
198
+ if embed_input:
199
+ self.word_emb = nn.Embedding(n_embed, d_embed or hidden_channels, padding_idx=self.padding_idx)
200
+ else:
201
+ self.word_emb = None
202
+
203
+ self.rel_attention = rel_attention
204
+
205
+ if not rel_attention:
206
+ self.pos_emb = PositionalEmbedding(self.d_model)
207
+
208
+ self.drop = nn.Dropout(dropemb)
209
+ self.layers = nn.ModuleList()
210
+
211
+ for _ in range(n_layer):
212
+ self.layers.append(
213
+ TransformerLayer(
214
+ n_head,
215
+ hidden_channels,
216
+ d_head,
217
+ d_inner,
218
+ kernel_size,
219
+ dropout,
220
+ dropatt=dropatt,
221
+ pre_lnorm=pre_lnorm,
222
+ rel_attention=rel_attention,
223
+ rel_window_size=rel_window_size,
224
+ )
225
+ )
226
+
227
+ def forward(self, dec_inp, mask=None, conditioning=0):
228
+ inp = dec_inp.transpose(1, 2)
229
+ mask = mask.bool().squeeze(1).unsqueeze(2)
230
+ # if self.word_emb is None:
231
+ # inp = dec_inp
232
+ # mask = sequence_mask(seq_lens, inp.shape[1], device=seq_lens.device, dtype=seq_lens.dtype).unsqueeze(2)
233
+ # else:
234
+ # inp = self.word_emb(dec_inp)
235
+ # # [bsz x L x 1]
236
+ # mask = (dec_inp != self.padding_idx).unsqueeze(2)
237
+
238
+ if not self.rel_attention:
239
+ pos_seq = torch.arange(inp.size(1), device=inp.device).to(inp.dtype)
240
+ pos_emb = self.pos_emb(pos_seq) * mask
241
+ else:
242
+ pos_emb = 0
243
+
244
+ out = self.drop(inp + pos_emb + conditioning)
245
+
246
+ for layer in self.layers:
247
+ out = layer(out, mask=mask)
248
+
249
+ # out = self.drop(out)
250
+ return rearrange(out, "b l h -> b h l")
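A small shape-check sketch for the FFT stack, using the plain dot-product attention branch (rel_attention=False); all sizes are placeholders.

import torch

from diff_ttsg.models.components.transformer import FFTransformer

fft = FFTransformer(n_layer=2, n_head=2, hidden_channels=192, d_head=64,
                    d_inner=768, kernel_size=3, dropout=0.1, dropatt=0.1,
                    rel_attention=False)
x = torch.randn(2, 192, 50)   # [batch, channels, time]
mask = torch.ones(2, 1, 50)   # 1 for valid frames
out = fft(x, mask)            # [batch, channels, time]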
diff_ttsg/models/diff_ttsg.py ADDED
@@ -0,0 +1,376 @@
1
+ import math
2
+ import random
3
+ from typing import Any
4
+
5
+ import torch
6
+ from lightning import LightningModule
7
+
8
+ import diff_ttsg.utils.monotonic_align as monotonic_align
9
+ from diff_ttsg import utils
10
+ from diff_ttsg.models.components.diffusion import Diffusion, Diffusion_Motion
11
+ from diff_ttsg.models.components.text_encoder import (MuMotionEncoder,
12
+ TextEncoder)
13
+ from diff_ttsg.utils.model import (denormalize, duration_loss,
14
+ fix_len_compatibility, generate_path,
15
+ sequence_mask)
16
+ from diff_ttsg.utils.utils import plot_tensor
17
+
18
+ log = utils.get_pylogger(__name__)
19
+
20
+ class Diff_TTSG(LightningModule):
21
+ def __init__(
22
+ self,
23
+ n_vocab,
24
+ n_spks,
25
+ spk_emb_dim,
26
+ n_enc_channels,
27
+ filter_channels,
28
+ filter_channels_dp,
29
+ n_heads,
30
+ n_enc_layers,
31
+ enc_kernel,
32
+ enc_dropout,
33
+ window_size,
34
+ n_feats,
35
+ n_motions,
36
+ dec_dim,
37
+ beta_min,
38
+ beta_max,
39
+ pe_scale,
40
+ mu_motion_encoder_params,
41
+ motion_reduction_factor,
42
+ motion_decoder_channels,
43
+ data_statistics,
44
+ out_size,
45
+ only_speech=False,
46
+ encoder_type="default",
47
+ optimizer=None
48
+ ):
49
+ super(Diff_TTSG, self).__init__()
50
+
51
+ self.save_hyperparameters(logger=False)
52
+
53
+ self.n_vocab = n_vocab
54
+ self.n_spks = n_spks
55
+ self.spk_emb_dim = spk_emb_dim
56
+ self.n_enc_channels = n_enc_channels
57
+ self.filter_channels = filter_channels
58
+ self.filter_channels_dp = filter_channels_dp
59
+ self.n_heads = n_heads
60
+ self.n_enc_layers = n_enc_layers
61
+ self.enc_kernel = enc_kernel
62
+ self.enc_dropout = enc_dropout
63
+ self.window_size = window_size
64
+ self.n_feats = n_feats
65
+ self.n_motions = n_motions
66
+ self.dec_dim = dec_dim
67
+ self.beta_min = beta_min
68
+ self.beta_max = beta_max
69
+ self.pe_scale = pe_scale
70
+ self.generate_motion = not only_speech
71
+ self.motion_reduction_factor = motion_reduction_factor
72
+ self.out_size = out_size
73
+ self.mu_diffusion_channels = motion_decoder_channels
74
+
75
+ if n_spks > 1:
76
+ self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
77
+ self.encoder = TextEncoder(n_vocab, n_feats, n_enc_channels,
78
+ filter_channels, filter_channels_dp, n_heads,
79
+ n_enc_layers, enc_kernel, enc_dropout, window_size, encoder_type=encoder_type)
80
+ self.decoder = Diffusion(n_feats, dec_dim, n_spks, spk_emb_dim, beta_min, beta_max, pe_scale)
81
+
82
+ if self.generate_motion:
83
+ self.motion_prior_loss = mu_motion_encoder_params.pop('prior_loss', True)
84
+ self.mu_motion_encoder = MuMotionEncoder(
85
+ input_channels=n_feats,
86
+ output_channels=n_motions,
87
+ **mu_motion_encoder_params
88
+ )
89
+ self.decoder_motion = Diffusion_Motion(
90
+ in_channels=n_motions,
91
+ motion_decoder_channels=motion_decoder_channels,
92
+ beta_min=beta_min,
93
+ beta_max=beta_max,
94
+ )
95
+
96
+ self.update_data_statistics(data_statistics)
97
+
98
+ def update_data_statistics(self, data_statistics):
99
+ if data_statistics is None:
100
+ data_statistics = {
101
+ 'mel_mean': 0.0,
102
+ 'mel_std': 1.0,
103
+ 'motion_mean': 0.0,
104
+ 'motion_std': 1.0,
105
+ }
106
+
107
+ self.register_buffer('mel_mean', torch.tensor(data_statistics['mel_mean']))
108
+ self.register_buffer('mel_std', torch.tensor(data_statistics['mel_std']))
109
+ self.register_buffer('motion_mean', torch.tensor(data_statistics['motion_mean']))
110
+ self.register_buffer('motion_std', torch.tensor(data_statistics['motion_std']))
111
+
112
+ @torch.inference_mode()
113
+ def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, stoc=False, spk=None, length_scale=1.0):
114
+ """
115
+ Generates mel-spectrogram (and, when motion generation is enabled, motion features) from text.
116
+ Returns a dictionary with the encoder and decoder outputs for mel and motion,
117
+ the generated alignment, and the denormalised mel and motion features.
119
+
120
+ Args:
121
+ x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
122
+ x_lengths (torch.Tensor): lengths of texts in batch.
123
+ n_timesteps (int): number of steps to use for reverse diffusion in decoder.
124
+ temperature (float, optional): controls variance of terminal distribution.
125
+ stoc (bool, optional): flag that adds stochastic term to the decoder sampler.
126
+ Usually, does not provide synthesis improvements.
127
+ length_scale (float, optional): controls speech pace.
128
+ Increase value to slow down generated speech and vice versa.
129
+ """
130
+ if isinstance(n_timesteps, dict):
131
+ n_timestep_mel = n_timesteps['mel']
132
+ n_timestep_motion = n_timesteps['motion']
133
+ else:
134
+ n_timestep_mel = n_timesteps
135
+ n_timestep_motion = n_timesteps
136
+
137
+ if isinstance(temperature, dict):
138
+ temperature_mel = temperature['mel']
139
+ temperature_motion = temperature['motion']
140
+ else:
141
+ temperature_mel = temperature
142
+ temperature_motion = temperature
143
+
144
+ if self.n_spks > 1:
145
+ # Get speaker embedding
146
+ spk = self.spk_emb(spk)
147
+
148
+ # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
149
+ mu_x, logw, x_mask = self.encoder(x, x_lengths, spk)
150
+
151
+ w = torch.exp(logw) * x_mask
152
+ w_ceil = torch.ceil(w) * length_scale
153
+ y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
154
+ y_max_length = int(y_lengths.max())
155
+ y_max_length_ = fix_len_compatibility(y_max_length)
156
+
157
+ # Using obtained durations `w` construct alignment map `attn`
158
+ y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
159
+ attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
160
+ attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
161
+
162
+ # Align encoded text and get mu_y
163
+ mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
164
+ mu_y = mu_y.transpose(1, 2)
165
+ encoder_outputs = mu_y[:, :, :y_max_length]
166
+
167
+
168
+ # Sample latent representation from terminal distribution N(mu_y, I)
169
+ z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature_mel
170
+ # Generate sample by performing reverse dynamics
171
+ decoder_outputs = self.decoder(z, y_mask, mu_y, n_timestep_mel, stoc, spk)
172
+ decoder_outputs = decoder_outputs[:, :, :y_max_length]
173
+
174
+ if self.generate_motion:
175
+ mu_y_motion = mu_y[:, :, ::self.motion_reduction_factor]
176
+ y_motion_mask = y_mask[:, :, ::self.motion_reduction_factor]
177
+ mu_y_motion = self.mu_motion_encoder(mu_y_motion, y_motion_mask)
178
+ encoder_outputs_motion = mu_y_motion[:, :, :y_max_length]
179
+ # sample latent representation from terminal distribution N(mu_y_motion, I)
180
+ z_motion = mu_y_motion + torch.randn_like(mu_y_motion, device=mu_y_motion.device) / temperature_motion
181
+ # Generate sample by performing reverse dynamics
182
+ decoder_outputs_motion = self.decoder_motion(z_motion, y_motion_mask, mu_y_motion, n_timestep_motion, stoc, spk)
183
+ decoder_outputs_motion = decoder_outputs_motion[:, :, :y_max_length]
184
+ else:
185
+ decoder_outputs_motion = None
186
+ encoder_outputs_motion = None
187
+
188
+ return {
189
+ 'encoder_outputs_mel': encoder_outputs,
190
+ 'decoder_outputs_mel': decoder_outputs,
191
+ 'encoder_outputs_motion': encoder_outputs_motion,
192
+ 'decoder_outputs_motion': decoder_outputs_motion,
193
+ 'attn': attn[:, :, :y_max_length],
194
+ 'mel': denormalize(decoder_outputs, self.mel_mean, self.mel_std),
195
+ 'motion': denormalize(decoder_outputs_motion, self.motion_mean, self.motion_std) if self.generate_motion else None,
196
+ }
197
+
198
+ def forward(self, x, x_lengths, y, y_lengths, y_motion, spk=None, out_size=None):
199
+ """
200
+ Computes 3 losses:
201
+ 1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS).
202
+ 2. prior loss: loss between mel-spectrogram and encoder outputs.
203
+ 3. diffusion loss: loss between gaussian noise and its reconstruction by diffusion-based decoder.
204
+
205
+ Args:
206
+ x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
207
+ x_lengths (torch.Tensor): lengths of texts in batch.
208
+ y (torch.Tensor): batch of corresponding mel-spectrograms.
209
+ y_lengths (torch.Tensor): lengths of mel-spectrograms in batch.
210
+ out_size (int, optional): length (in mel's sampling rate) of segment to cut, on which decoder will be trained.
211
+ Should be divisible by 2^{num of UNet downsamplings}. Needed to increase batch size.
212
+ """
213
+ if self.n_spks > 1:
214
+ # Get speaker embedding
215
+ spk = self.spk_emb(spk)
216
+
217
+ # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
218
+ mu_x, logw, x_mask = self.encoder(x, x_lengths, spk)
219
+ y_max_length = y.shape[-1]
220
+
221
+ y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask)
222
+ attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
223
+
224
+ # Use MAS to find most likely alignment `attn` between text and mel-spectrogram
225
+ with torch.no_grad():
226
+ const = -0.5 * math.log(2 * math.pi) * self.n_feats
227
+ factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
228
+ y_square = torch.matmul(factor.transpose(1, 2), y ** 2)
229
+ y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
230
+ mu_square = torch.sum(factor * (mu_x ** 2), 1).unsqueeze(-1)
231
+ log_prior = y_square - y_mu_double + mu_square + const
232
+
233
+ attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
234
+ attn = attn.detach()
235
+
236
+ # Compute loss between predicted log-scaled durations and those obtained from MAS
237
+ logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask
238
+ dur_loss = duration_loss(logw, logw_, x_lengths)
239
+
240
+ # Cut a small segment of mel-spectrogram in order to increase batch size
241
+ if not isinstance(out_size, type(None)):
242
+ max_offset = (y_lengths - out_size).clamp(0) # cut a random segment of size `out_size` from each sample in batch max_offset: [758, 160, 773]
243
+ offset_ranges = list(zip([0] * max_offset.shape[0], max_offset.cpu().numpy())) # offset ranges for each sample in batch offset_ranges: [(0, 758), (0, 160), (0, 773)]
244
+ out_offset = torch.LongTensor([
245
+ torch.tensor(random.choice(range(start, end)) if end > start else 0)
246
+ for start, end in offset_ranges
247
+ ]).to(y_lengths)
248
+ attn_cut = torch.zeros(attn.shape[0], attn.shape[1], out_size, dtype=attn.dtype, device=attn.device)
249
+ y_cut = torch.zeros(y.shape[0], self.n_feats, out_size, dtype=y.dtype, device=y.device)
250
+
251
+ if self.generate_motion:
252
+ y_motion_cut = torch.zeros(y_motion.shape[0], self.n_motions, out_size, dtype=y_motion.dtype, device=y_motion.device)
253
+
254
+ y_cut_lengths = []
255
+ for i, (y_, out_offset_) in enumerate(zip(y, out_offset)):
256
+ y_cut_length = out_size + (y_lengths[i] - out_size).clamp(None, 0)
257
+ y_cut_lengths.append(y_cut_length)
258
+ cut_lower, cut_upper = out_offset_, out_offset_ + y_cut_length
259
+ y_cut[i, :, :y_cut_length] = y_[:, cut_lower:cut_upper]
260
+ if self.generate_motion:
261
+ y_motion_cut[i, :, :y_cut_length] = y_motion[i, :, cut_lower:cut_upper]
262
+
263
+ attn_cut[i, :, :y_cut_length] = attn[i, :, cut_lower:cut_upper]
264
+ y_cut_lengths = torch.LongTensor(y_cut_lengths)
265
+ y_cut_mask = sequence_mask(y_cut_lengths).unsqueeze(1).to(y_mask)
266
+
267
+ attn = attn_cut
268
+ y = y_cut
269
+ if self.generate_motion:
270
+ y_motion = y_motion_cut
271
+
272
+ y_mask = y_cut_mask
273
+
274
+ # Align encoded text with mel-spectrogram and get mu_y segment
275
+ mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
276
+ mu_y = mu_y.transpose(1, 2)
277
+
278
+
279
+
280
+ # Compute loss of score-based decoder
281
+ diff_loss, xt = self.decoder.compute_loss(y, y_mask, mu_y, spk)
282
+ if self.generate_motion:
283
+ # Reduce motion features
284
+ mu_y_motion = mu_y[:, :, ::self.motion_reduction_factor]
285
+ y_motion_mask = y_mask[:, :, ::self.motion_reduction_factor]
286
+ y_motion = y_motion[:, :, ::self.motion_reduction_factor]
287
+
288
+ mu_y_motion = self.mu_motion_encoder(mu_y_motion, y_motion_mask)
289
+ diff_loss_motion, xt_motion = self.decoder_motion.compute_loss(y_motion, y_motion_mask, mu_y_motion, spk)
290
+ else:
291
+ diff_loss_motion = 0
292
+
293
+ # Compute loss between aligned encoder outputs and mel-spectrogram
294
+ prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
295
+ prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
296
+
297
+ if self.generate_motion and self.motion_prior_loss:
298
+ prior_loss_motion = torch.sum(0.5 * ((y_motion - mu_y_motion) ** 2 + math.log(2 * math.pi)) * y_motion_mask)
299
+ prior_loss_motion = prior_loss_motion / (torch.sum(y_motion_mask) * self.n_motions)
300
+ else:
301
+ prior_loss_motion = 0
302
+
303
+ return dur_loss, prior_loss + prior_loss_motion, diff_loss + diff_loss_motion
304
+
305
+
306
+ def configure_optimizers(self) -> Any:
307
+ optimizer = self.hparams.optimizer(params=self.parameters())
308
+ return {'optimizer': optimizer}
309
+
310
+ def get_losses(self, batch):
312
+ x, x_lengths = batch['x'], batch['x_lengths']
313
+ y, y_lengths = batch['y'], batch['y_lengths']
314
+ y_motion = batch['y_motion']
315
+ dur_loss, prior_loss, diff_loss = self(x, x_lengths, y, y_lengths, y_motion, out_size=self.out_size)
316
+ return {
317
+ 'dur_loss': dur_loss,
318
+ 'prior_loss': prior_loss,
319
+ 'diff_loss': diff_loss,
320
+ }
321
+
322
+
323
+
324
+ def training_step(self, batch: Any, batch_idx: int):
325
+ loss_dict = self.get_losses(batch)
326
+ self.log('step', float(self.global_step), on_step=True, on_epoch=True, logger=True, sync_dist=True)
327
+
328
+ self.log('sub_loss/train_dur_loss', loss_dict['dur_loss'], on_step=True, on_epoch=True, logger=True, sync_dist=True)
329
+ self.log('sub_loss/train_prior_loss', loss_dict['prior_loss'], on_step=True, on_epoch=True, logger=True, sync_dist=True)
330
+ self.log('sub_loss/train_diff_loss', loss_dict['diff_loss'], on_step=True, on_epoch=True, logger=True, sync_dist=True)
331
+
332
+ total_loss = sum(loss_dict.values())
333
+ self.log('loss/train', total_loss, on_step=True, on_epoch=True, logger=True, prog_bar=True, sync_dist=True)
334
+
335
+ return {'loss': total_loss, 'log': loss_dict }
336
+
337
+ def validation_step(self, batch: Any, batch_idx: int):
338
+ loss_dict = self.get_losses(batch)
339
+ self.log('sub_loss/val_dur_loss', loss_dict['dur_loss'], on_step=True, on_epoch=True, logger=True, sync_dist=True)
340
+ self.log('sub_loss/val_prior_loss', loss_dict['prior_loss'], on_step=True, on_epoch=True, logger=True, sync_dist=True)
341
+ self.log('sub_loss/val_diff_loss', loss_dict['diff_loss'], on_step=True, on_epoch=True, logger=True, sync_dist=True)
342
+
343
+ total_loss = sum(loss_dict.values())
344
+ self.log('loss/val', total_loss, on_step=True, on_epoch=True, logger=True, prog_bar=True, sync_dist=True)
345
+
346
+ return total_loss
347
+
348
+ def on_validation_end(self) -> None:
349
+ if self.trainer.is_global_zero:
350
+ one_batch = next(iter(self.trainer.val_dataloaders))
351
+ if self.current_epoch == 0:
352
+ log.debug("Plotting original samples")
353
+ for i in range(4):
354
+ y = one_batch['y'][i].unsqueeze(0).to(self.device)
355
+ y_motion = one_batch['y_motion'][i].unsqueeze(0).to(self.device)
356
+ self.logger.experiment.add_image(f'original/mel_{i}', plot_tensor(y.squeeze().cpu()), self.current_epoch, dataformats='HWC')
357
+ if self.generate_motion:
358
+ self.logger.experiment.add_image(f'original/mel_{i}', plot_tensor(y_motion.squeeze().cpu()), self.current_epoch, dataformats='HWC')
359
+
360
+ log.debug(f'Synthesising...')
361
+ for i in range(4):
362
+ x = one_batch['x'][i].unsqueeze(0).to(self.device)
363
+ x_lengths = one_batch['x_lengths'][i].unsqueeze(0).to(self.device)
364
+ output = self.synthesise(x, x_lengths, n_timesteps=20)
365
+ y_enc, y_dec = output['encoder_outputs_mel'], output['decoder_outputs_mel']
366
+ y_motion_enc, y_motion_dec, attn = output['encoder_outputs_motion'], output['decoder_outputs_motion'], output['attn']
367
+ self.logger.experiment.add_image(f'generated_enc/{i}', plot_tensor(y_enc.squeeze().cpu()), self.current_epoch, dataformats='HWC')
368
+ self.logger.experiment.add_image(f'generated_dec/{i}', plot_tensor(y_dec.squeeze().cpu()), self.current_epoch, dataformats='HWC')
369
+ if self.generate_motion:
370
+ self.logger.experiment.add_image(f'generated_enc_motion/{i}', plot_tensor(y_motion_enc.squeeze().cpu()), self.current_epoch, dataformats='HWC')
371
+ self.logger.experiment.add_image(f'generated_dec_motion/{i}', plot_tensor(y_motion_dec.squeeze().cpu()), self.current_epoch, dataformats='HWC')
372
+
373
+ self.logger.experiment.add_image(f'alignment/{i}', plot_tensor(attn.squeeze().cpu()), self.current_epoch, dataformats='HWC')
374
+
375
+
376
+
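For illustration, here is a minimal, self-contained sketch (not from this repository) of the random fixed-length segment cut performed above when out_size is set, so that long utterances fit the decoder's training window; the sequence_mask helper and all shapes/values are assumed for the example:

    import random
    import torch

    def sequence_mask(lengths, max_len):
        # illustrative stand-in for the project's sequence_mask utility
        ids = torch.arange(max_len, device=lengths.device)
        return ids.unsqueeze(0) < lengths.unsqueeze(1)

    out_size = 172                             # frames to keep per sample (assumed value)
    y = torch.randn(3, 80, 900)                # batch of mel-spectrograms [B, n_feats, T]
    y_lengths = torch.tensor([900, 400, 150])  # true lengths; one is shorter than out_size

    max_offset = (y_lengths - out_size).clamp(0)
    offset_ranges = list(zip([0] * max_offset.shape[0], max_offset.cpu().numpy()))
    out_offset = torch.LongTensor([
        random.choice(range(start, end)) if end > start else 0
        for start, end in offset_ranges
    ])

    y_cut = torch.zeros(y.shape[0], y.shape[1], out_size)
    y_cut_lengths = []
    for i, (y_, off) in enumerate(zip(y, out_offset.tolist())):
        # shorter utterances keep their full length and are zero-padded to out_size
        cut_len = int(out_size + (y_lengths[i] - out_size).clamp(max=0))
        y_cut_lengths.append(cut_len)
        y_cut[i, :, :cut_len] = y_[:, off:off + cut_len]

    y_cut_mask = sequence_mask(torch.LongTensor(y_cut_lengths), out_size).unsqueeze(1)
    print(y_cut.shape, y_cut_mask.shape)  # torch.Size([3, 80, 172]) torch.Size([3, 1, 172])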
diff_ttsg/models/mnist_module.py ADDED
@@ -0,0 +1,137 @@
+ from typing import Any
+
+ import torch
+ from lightning import LightningModule
+ from torchmetrics import MaxMetric, MeanMetric
+ from torchmetrics.classification.accuracy import Accuracy
+
+
+ class MNISTLitModule(LightningModule):
+     """Example of LightningModule for MNIST classification.
+
+     A LightningModule organizes your PyTorch code into 6 sections:
+         - Initialization (__init__)
+         - Train Loop (training_step)
+         - Validation loop (validation_step)
+         - Test loop (test_step)
+         - Prediction Loop (predict_step)
+         - Optimizers and LR Schedulers (configure_optimizers)
+
+     Docs:
+         https://lightning.ai/docs/pytorch/latest/common/lightning_module.html
+     """
+
+     def __init__(
+         self,
+         net: torch.nn.Module,
+         optimizer: torch.optim.Optimizer,
+         scheduler: torch.optim.lr_scheduler,
+     ):
+         super().__init__()
+
+         # this line allows to access init params with 'self.hparams' attribute
+         # also ensures init params will be stored in ckpt
+         self.save_hyperparameters(logger=False)
+
+         self.net = net
+
+         # loss function
+         self.criterion = torch.nn.CrossEntropyLoss()
+
+         # metric objects for calculating and averaging accuracy across batches
+         self.train_acc = Accuracy(task="multiclass", num_classes=10)
+         self.val_acc = Accuracy(task="multiclass", num_classes=10)
+         self.test_acc = Accuracy(task="multiclass", num_classes=10)
+
+         # for averaging loss across batches
+         self.train_loss = MeanMetric()
+         self.val_loss = MeanMetric()
+         self.test_loss = MeanMetric()
+
+         # for tracking best so far validation accuracy
+         self.val_acc_best = MaxMetric()
+
+     def forward(self, x: torch.Tensor):
+         return self.net(x)
+
+     def on_train_start(self):
+         # by default lightning executes validation step sanity checks before training starts,
+         # so it's worth to make sure validation metrics don't store results from these checks
+         self.val_loss.reset()
+         self.val_acc.reset()
+         self.val_acc_best.reset()
+
+     def model_step(self, batch: Any):
+         x, y = batch
+         logits = self.forward(x)
+         loss = self.criterion(logits, y)
+         preds = torch.argmax(logits, dim=1)
+         return loss, preds, y
+
+     def training_step(self, batch: Any, batch_idx: int):
+         loss, preds, targets = self.model_step(batch)
+
+         # update and log metrics
+         self.train_loss(loss)
+         self.train_acc(preds, targets)
+         self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True)
+         self.log("train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True)
+
+         # return loss or backpropagation will fail
+         return loss
+
+     def on_train_epoch_end(self):
+         pass
+
+     def validation_step(self, batch: Any, batch_idx: int):
+         loss, preds, targets = self.model_step(batch)
+
+         # update and log metrics
+         self.val_loss(loss)
+         self.val_acc(preds, targets)
+         self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)
+         self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
+
+     def on_validation_epoch_end(self):
+         acc = self.val_acc.compute()  # get current val acc
+         self.val_acc_best(acc)  # update best so far val acc
+         # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object
+         # otherwise metric would be reset by lightning after each epoch
+         self.log("val/acc_best", self.val_acc_best.compute(), sync_dist=True, prog_bar=True)
+
+     def test_step(self, batch: Any, batch_idx: int):
+         loss, preds, targets = self.model_step(batch)
+
+         # update and log metrics
+         self.test_loss(loss)
+         self.test_acc(preds, targets)
+         self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True)
+         self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True)
+
+     def on_test_epoch_end(self):
+         pass
+
+     def configure_optimizers(self):
+         """Choose what optimizers and learning-rate schedulers to use in your optimization.
+         Normally you'd need one. But in the case of GANs or similar you might have multiple.
+
+         Examples:
+             https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
+         """
+         optimizer = self.hparams.optimizer(params=self.parameters())
+         if self.hparams.scheduler is not None:
+             scheduler = self.hparams.scheduler(optimizer=optimizer)
+             return {
+                 "optimizer": optimizer,
+                 "lr_scheduler": {
+                     "scheduler": scheduler,
+                     "monitor": "val/loss",
+                     "interval": "epoch",
+                     "frequency": 1,
+                 },
+             }
+         return {"optimizer": optimizer}
+
+
+ if __name__ == "__main__":
+     _ = MNISTLitModule(None, None, None)
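MNISTLitModule expects partially-instantiated optimizer and scheduler factories (normally supplied by the template's Hydra configs); configure_optimizers later calls them with the module's parameters. A hypothetical usage sketch, with functools.partial standing in for the config system and a toy network:

    import functools
    import torch
    from diff_ttsg.models.mnist_module import MNISTLitModule

    net = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 10))
    model = MNISTLitModule(
        net=net,
        optimizer=functools.partial(torch.optim.Adam, lr=1e-3),  # invoked later as optimizer(params=...)
        scheduler=functools.partial(torch.optim.lr_scheduler.ReduceLROnPlateau, patience=5),
    )

    # single forward pass on a fake MNIST batch
    logits = model(torch.randn(4, 1, 28, 28))
    print(logits.shape)  # torch.Size([4, 10])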
diff_ttsg/resources/cmu_dictionary ADDED
The diff for this file is too large to render.
 
diff_ttsg/text/LICENSE ADDED
@@ -0,0 +1,30 @@
+ CMUdict
+ -------
+
+ CMUdict (the Carnegie Mellon Pronouncing Dictionary) is a free
+ pronouncing dictionary of English, suitable for uses in speech
+ technology and is maintained by the Speech Group in the School of
+ Computer Science at Carnegie Mellon University.
+
+ The Carnegie Mellon Speech Group does not guarantee the accuracy of
+ this dictionary, nor its suitability for any specific purpose. In
+ fact, we expect a number of errors, omissions and inconsistencies to
+ remain in the dictionary. We intend to continually update the
+ dictionary by correction existing entries and by adding new ones. From
+ time to time a new major version will be released.
+
+ We welcome input from users: Please send email to Alex Rudnicky
+ (air+cmudict@cs.cmu.edu).
+
+ The Carnegie Mellon Pronouncing Dictionary, in its current and
+ previous versions is Copyright (C) 1993-2014 by Carnegie Mellon
+ University. Use of this dictionary for any research or commercial
+ purpose is completely unrestricted. If you make use of or
+ redistribute this material we request that you acknowledge its
+ origin in your descriptions.
+
+ If you add words to or correct words in your version of this
+ dictionary, we would appreciate it if you could send these additions
+ and corrections to us (air+cmudict@cs.cmu.edu) for consideration in a
+ subsequent version. All submissions will be reviewed and approved by
+ the current maintainer, Alex Rudnicky at Carnegie Mellon.
diff_ttsg/text/__init__.py ADDED
@@ -0,0 +1,96 @@
+ """ from https://github.com/keithito/tacotron """
+
+ import re
+
+ from diff_ttsg.text import cleaners
+ from diff_ttsg.text.symbols import symbols
+
+ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
+ _id_to_symbol = {i: s for i, s in enumerate(symbols)}
+
+ _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
+
+
+ def get_arpabet(word, dictionary):
+     word_arpabet = dictionary.lookup(word)
+     if word_arpabet is not None:
+         return "{" + word_arpabet[0] + "}"
+     else:
+         return word
+
+
+ def text_to_sequence(text, cleaner_names=["english_cleaners"], dictionary=None):
+     '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
+
+     The text can optionally have ARPAbet sequences enclosed in curly braces embedded
+     in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
+
+     Args:
+         text: string to convert to a sequence
+         cleaner_names: names of the cleaner functions to run the text through
+         dictionary: arpabet class with arpabet dictionary
+
+     Returns:
+         List of integers corresponding to the symbols in the text
+     '''
+     sequence = []
+     space = _symbols_to_sequence(' ')
+     # Check for curly braces and treat their contents as ARPAbet:
+     while len(text):
+         m = _curly_re.match(text)
+         if not m:
+             clean_text = _clean_text(text, cleaner_names)
+             if dictionary is not None:
+                 clean_text = [get_arpabet(w, dictionary) for w in clean_text.split(" ")]
+                 for i in range(len(clean_text)):
+                     t = clean_text[i]
+                     if t.startswith("{"):
+                         sequence += _arpabet_to_sequence(t[1:-1])
+                     else:
+                         sequence += _symbols_to_sequence(t)
+                     sequence += space
+             else:
+                 sequence += _symbols_to_sequence(clean_text)
+             break
+         sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
+         sequence += _arpabet_to_sequence(m.group(2))
+         text = m.group(3)
+
+     # remove trailing space
+     if dictionary is not None:
+         sequence = sequence[:-1] if sequence[-1] == space[0] else sequence
+     return sequence
+
+
+ def sequence_to_text(sequence):
+     '''Converts a sequence of IDs back to a string'''
+     result = ''
+     for symbol_id in sequence:
+         if symbol_id in _id_to_symbol:
+             s = _id_to_symbol[symbol_id]
+             # Enclose ARPAbet back in curly braces:
+             if len(s) > 1 and s[0] == '@':
+                 s = '{%s}' % s[1:]
+             result += s
+     return result.replace('}{', ' ')
+
+
+ def _clean_text(text, cleaner_names):
+     for name in cleaner_names:
+         cleaner = getattr(cleaners, name)
+         if not cleaner:
+             raise Exception('Unknown cleaner: %s' % name)
+         text = cleaner(text)
+     return text
+
+
+ def _symbols_to_sequence(symbols):
+     return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
+
+
+ def _arpabet_to_sequence(text):
+     return _symbols_to_sequence(['@' + s for s in text.split()])
+
+
+ def _should_keep_symbol(s):
+     return s in _symbol_to_id and s != '_' and s != '~'
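As the text_to_sequence docstring notes, ARPAbet sequences can be embedded in curly braces and are mapped to '@'-prefixed symbols. A small usage sketch, assuming the diff_ttsg package and its symbols module are importable:

    from diff_ttsg.text import text_to_sequence, sequence_to_text

    seq = text_to_sequence("Turn left on {HH AW1 S S T AH0 N} Street.", ["english_cleaners"])
    print(seq)                    # list of integer symbol IDs
    print(sequence_to_text(seq))  # round-trips to "turn left on {HH AW1 S S T AH0 N} street."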
diff_ttsg/text/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (3.41 kB)
 
diff_ttsg/text/__pycache__/cleaners.cpython-310.pyc ADDED
Binary file (1.98 kB)
 
diff_ttsg/text/__pycache__/cmudict.cpython-310.pyc ADDED
Binary file (2.22 kB)
 
diff_ttsg/text/__pycache__/numbers.cpython-310.pyc ADDED
Binary file (2.22 kB)
 
diff_ttsg/text/__pycache__/symbols.cpython-310.pyc ADDED
Binary file (604 Bytes)
 
diff_ttsg/text/cleaners.py ADDED
@@ -0,0 +1,73 @@
+ """ from https://github.com/keithito/tacotron """
+
+ import re
+ from unidecode import unidecode
+ from .numbers import normalize_numbers
+
+
+ _whitespace_re = re.compile(r'\s+')
+
+ _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
+     ('mrs', 'misess'),
+     ('mr', 'mister'),
+     ('dr', 'doctor'),
+     ('st', 'saint'),
+     ('co', 'company'),
+     ('jr', 'junior'),
+     ('maj', 'major'),
+     ('gen', 'general'),
+     ('drs', 'doctors'),
+     ('rev', 'reverend'),
+     ('lt', 'lieutenant'),
+     ('hon', 'honorable'),
+     ('sgt', 'sergeant'),
+     ('capt', 'captain'),
+     ('esq', 'esquire'),
+     ('ltd', 'limited'),
+     ('col', 'colonel'),
+     ('ft', 'fort'),
+ ]]
+
+
+ def expand_abbreviations(text):
+     for regex, replacement in _abbreviations:
+         text = re.sub(regex, replacement, text)
+     return text
+
+
+ def expand_numbers(text):
+     return normalize_numbers(text)
+
+
+ def lowercase(text):
+     return text.lower()
+
+
+ def collapse_whitespace(text):
+     return re.sub(_whitespace_re, ' ', text)
+
+
+ def convert_to_ascii(text):
+     return unidecode(text)
+
+
+ def basic_cleaners(text):
+     text = lowercase(text)
+     text = collapse_whitespace(text)
+     return text
+
+
+ def transliteration_cleaners(text):
+     text = convert_to_ascii(text)
+     text = lowercase(text)
+     text = collapse_whitespace(text)
+     return text
+
+
+ def english_cleaners(text):
+     text = convert_to_ascii(text)
+     text = lowercase(text)
+     text = expand_numbers(text)
+     text = expand_abbreviations(text)
+     text = collapse_whitespace(text)
+     return text
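english_cleaners chains ASCII conversion, lowercasing, number expansion, abbreviation expansion and whitespace collapsing. A quick illustrative call (expected output shown approximately; exact number wording comes from inflect):

    from diff_ttsg.text.cleaners import english_cleaners

    print(english_cleaners("Dr. Smith lives at 221 Baker St."))
    # roughly: "doctor smith lives at two hundred twenty-one baker saint."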
diff_ttsg/text/cmudict.py ADDED
@@ -0,0 +1,60 @@
+ """ from https://github.com/keithito/tacotron """
+
+ import re
+
+
+ valid_symbols = [
+     'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',
+     'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',
+     'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY',
+     'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1',
+     'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0',
+     'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW',
+     'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'
+ ]
+
+ _valid_symbol_set = set(valid_symbols)
+
+
+ class CMUDict:
+     def __init__(self, file_or_path, keep_ambiguous=True):
+         if isinstance(file_or_path, str):
+             with open(file_or_path, encoding='latin-1') as f:
+                 entries = _parse_cmudict(f)
+         else:
+             entries = _parse_cmudict(file_or_path)
+         if not keep_ambiguous:
+             entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
+         self._entries = entries
+
+     def __len__(self):
+         return len(self._entries)
+
+     def lookup(self, word):
+         return self._entries.get(word.upper())
+
+
+ _alt_re = re.compile(r'\([0-9]+\)')
+
+
+ def _parse_cmudict(file):
+     cmudict = {}
+     for line in file:
+         if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"):
+             parts = line.split('  ')
+             word = re.sub(_alt_re, '', parts[0])
+             pronunciation = _get_pronunciation(parts[1])
+             if pronunciation:
+                 if word in cmudict:
+                     cmudict[word].append(pronunciation)
+                 else:
+                     cmudict[word] = [pronunciation]
+     return cmudict
+
+
+ def _get_pronunciation(s):
+     parts = s.strip().split(' ')
+     for part in parts:
+         if part not in _valid_symbol_set:
+             return None
+     return ' '.join(parts)
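A hypothetical lookup sketch against the dictionary file bundled in this commit (diff_ttsg/resources/cmu_dictionary); the printed pronunciation is indicative only:

    from diff_ttsg.text.cmudict import CMUDict

    cmu = CMUDict("diff_ttsg/resources/cmu_dictionary")
    print(len(cmu))             # number of words with at least one valid pronunciation
    print(cmu.lookup("hello"))  # e.g. ['HH AH0 L OW1']; None if the word is missing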
diff_ttsg/text/numbers.py ADDED
@@ -0,0 +1,72 @@
+ """ from https://github.com/keithito/tacotron """
+
+ import inflect
+ import re
+
+
+ _inflect = inflect.engine()
+ _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
+ _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
+ _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
+ _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
+ _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
+ _number_re = re.compile(r'[0-9]+')
+
+
+ def _remove_commas(m):
+     return m.group(1).replace(',', '')
+
+
+ def _expand_decimal_point(m):
+     return m.group(1).replace('.', ' point ')
+
+
+ def _expand_dollars(m):
+     match = m.group(1)
+     parts = match.split('.')
+     if len(parts) > 2:
+         return match + ' dollars'
+     dollars = int(parts[0]) if parts[0] else 0
+     cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
+     if dollars and cents:
+         dollar_unit = 'dollar' if dollars == 1 else 'dollars'
+         cent_unit = 'cent' if cents == 1 else 'cents'
+         return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
+     elif dollars:
+         dollar_unit = 'dollar' if dollars == 1 else 'dollars'
+         return '%s %s' % (dollars, dollar_unit)
+     elif cents:
+         cent_unit = 'cent' if cents == 1 else 'cents'
+         return '%s %s' % (cents, cent_unit)
+     else:
+         return 'zero dollars'
+
+
+ def _expand_ordinal(m):
+     return _inflect.number_to_words(m.group(0))
+
+
+ def _expand_number(m):
+     num = int(m.group(0))
+     if num > 1000 and num < 3000:
+         if num == 2000:
+             return 'two thousand'
+         elif num > 2000 and num < 2010:
+             return 'two thousand ' + _inflect.number_to_words(num % 100)
+         elif num % 100 == 0:
+             return _inflect.number_to_words(num // 100) + ' hundred'
+         else:
+             return _inflect.number_to_words(num, andword='', zero='oh',
+                                             group=2).replace(', ', ' ')
+     else:
+         return _inflect.number_to_words(num, andword='')
+
+
+ def normalize_numbers(text):
+     text = re.sub(_comma_number_re, _remove_commas, text)
+     text = re.sub(_pounds_re, r'\1 pounds', text)
+     text = re.sub(_dollars_re, _expand_dollars, text)
+     text = re.sub(_decimal_number_re, _expand_decimal_point, text)
+     text = re.sub(_ordinal_re, _expand_ordinal, text)
+     text = re.sub(_number_re, _expand_number, text)
+     return text
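normalize_numbers applies the substitutions in a fixed order (commas, pounds, dollars, decimals, ordinals, then bare numbers), so currency amounts are expanded before the plain-number rule runs. A quick sketch:

    from diff_ttsg.text.numbers import normalize_numbers

    print(normalize_numbers("I paid $3.50 for the 2nd ticket in 2007."))
    # -> "I paid three dollars, fifty cents for the second ticket in two thousand seven."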