Spaces:
Runtime error
Upload 238 files
This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +6 -0
- app.py +138 -0
- audios/audio0.flac +0 -0
- audios/audio1.flac +0 -0
- audios/audio2.flac +0 -0
- audios/audio3.flac +0 -0
- audios/audio4.flac +0 -0
- audios/audio5.flac +0 -0
- audios/audio6.flac +0 -0
- checkpoints/.keep +0 -0
- checkpoints/checkpoint_0.pt +3 -0
- checkpoints/vocoder.pt +3 -0
- config.py +52 -0
- datas/__init__.py +0 -0
- datas/__pycache__/__init__.cpython-311.pyc +0 -0
- datas/__pycache__/dataset.cpython-311.pyc +0 -0
- datas/dataset.py +52 -0
- datas/sampler.py +121 -0
- models/__init__.py +0 -0
- models/__pycache__/__init__.cpython-310.pyc +0 -0
- models/__pycache__/__init__.cpython-311.pyc +0 -0
- models/__pycache__/dit.cpython-311.pyc +0 -0
- models/__pycache__/duration_predictor.cpython-311.pyc +0 -0
- models/__pycache__/estimator.cpython-311.pyc +0 -0
- models/__pycache__/flow_matching.cpython-311.pyc +0 -0
- models/__pycache__/model.cpython-310.pyc +0 -0
- models/__pycache__/model.cpython-311.pyc +0 -0
- models/__pycache__/reference_encoder.cpython-311.pyc +0 -0
- models/__pycache__/text_encoder.cpython-311.pyc +0 -0
- models/dit.py +205 -0
- models/duration_predictor.py +40 -0
- models/estimator.py +161 -0
- models/flow_matching.py +108 -0
- models/model.py +194 -0
- models/reference_encoder.py +93 -0
- models/text_encoder.py +49 -0
- monotonic_align/__init__.py +16 -0
- monotonic_align/__pycache__/__init__.cpython-310.pyc +0 -0
- monotonic_align/__pycache__/__init__.cpython-311.pyc +0 -0
- monotonic_align/__pycache__/core.cpython-310.pyc +0 -0
- monotonic_align/__pycache__/core.cpython-311.pyc +0 -0
- monotonic_align/core.py +46 -0
- requirements.txt +12 -0
- text/LICENSE +19 -0
- text/__init__.py +71 -0
- text/__pycache__/__init__.cpython-310.pyc +0 -0
- text/__pycache__/__init__.cpython-311.pyc +0 -0
- text/__pycache__/cleaners.cpython-310.pyc +0 -0
- text/__pycache__/cleaners.cpython-311.pyc +0 -0
- text/__pycache__/english.cpython-310.pyc +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+ text/custom_pypinyin_dict/__pycache__/cc_cedict_0.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
+ text/custom_pypinyin_dict/__pycache__/cc_cedict_0.cpython-38.pyc filter=lfs diff=lfs merge=lfs -text
+ text/custom_pypinyin_dict/__pycache__/cc_cedict_1.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
+ text/custom_pypinyin_dict/__pycache__/cc_cedict_1.cpython-38.pyc filter=lfs diff=lfs merge=lfs -text
+ text/custom_pypinyin_dict/__pycache__/cc_cedict_2.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
+ text/custom_pypinyin_dict/__pycache__/cc_cedict_2.cpython-38.pyc filter=lfs diff=lfs merge=lfs -text
app.py
ADDED
@@ -0,0 +1,138 @@
import os

from dataclasses import asdict
from text import symbols
import torch
import torchaudio

from utils.audio import LogMelSpectrogram
from config import ModelConfig, VocosConfig, MelConfig
from models.model import StableTTS
from vocos_pytorch.models.model import Vocos
from text.english import english_to_ipa2
from text import cleaned_text_to_sequence
from datas.dataset import intersperse

import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

device = 'cpu'

@ torch.inference_mode()
def inference(text: str, ref_audio: torch.Tensor, checkpoint_path: str, step: int=10) -> torch.Tensor:
    global last_checkpoint_path
    if checkpoint_path != last_checkpoint_path:
        tts_model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
        last_checkpoint_path = checkpoint_path

    phonemizer = english_to_ipa2

    # prepare input for tts model
    x = torch.tensor(intersperse(cleaned_text_to_sequence(phonemizer(text)), item=0), dtype=torch.long, device=device).unsqueeze(0)
    x_len = torch.tensor([x.size(-1)], dtype=torch.long, device=device)
    waveform, sr = torchaudio.load(ref_audio)
    if sr != sample_rate:
        waveform = torchaudio.functional.resample(waveform, sr, sample_rate)
    y = mel_extractor(waveform).to(device)

    # inference
    mel = tts_model.synthesise(x, x_len, step, y=y, temperature=0.667, length_scale=1)['decoder_outputs']
    audio = vocoder(mel)

    # process output for gradio
    audio_output = (sample_rate, (audio.cpu().squeeze(0).numpy() * 32767).astype(np.int16)) # (samplerate, int16 audio) for gr.Audio
    mel_output = plot_mel_spectrogram(mel.cpu().squeeze(0).numpy()) # get the plot of mel
    return audio_output, mel_output

def get_pipeline(n_vocab: int, tts_model_config: ModelConfig, mel_config: MelConfig, vocoder_config: VocosConfig, tts_checkpoint_path, vocoder_checkpoint_path):
    tts_model = StableTTS(n_vocab, mel_config.n_mels, **asdict(tts_model_config))
    mel_extractor = LogMelSpectrogram(mel_config)
    vocoder = Vocos(vocoder_config, mel_config)
    # tts_model.load_state_dict(torch.load(tts_checkpoint_path, map_location='cpu'))
    tts_model.to(device)
    tts_model.eval()
    vocoder.load_state_dict(torch.load(vocoder_checkpoint_path, map_location='cpu'))
    vocoder.to(device)
    vocoder.eval()
    return tts_model, mel_extractor, vocoder

def plot_mel_spectrogram(mel_spectrogram):
    fig, ax = plt.subplots(figsize=(20, 8))
    ax.imshow(mel_spectrogram, aspect='auto', origin='lower')
    plt.axis('off')
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0) # remove white edges
    return fig


def main():
    tts_model_config = ModelConfig()
    mel_config = MelConfig()
    vocoder_config = VocosConfig()

    tts_checkpoint_path = './checkpoints' # the folder that contains StableTTS checkpoints
    vocoder_checkpoint_path = './checkpoints/vocoder.pt'

    global tts_model, mel_extractor, vocoder, sample_rate, last_checkpoint_path
    sample_rate = mel_config.sample_rate
    last_checkpoint_path = None
    tts_model, mel_extractor, vocoder = get_pipeline(len(symbols), tts_model_config, mel_config, vocoder_config, tts_checkpoint_path, vocoder_checkpoint_path)

    tts_checkpoint_path = [path for path in Path(tts_checkpoint_path).rglob('*.pt') if 'optimizer' and 'vocoder' not in path.name]
    audios = list(Path('./audios').rglob('*.wav')) + list(Path('./audios').rglob('*.flac'))

    # gradio webui
    gui_title = 'StableTTS'
    gui_description = """Next-generation TTS model using flow-matching and DiT, inspired by Stable Diffusion 3."""
    with gr.Blocks(analytics_enabled=False) as demo:

        with gr.Row():
            with gr.Column():
                gr.Markdown(f"# {gui_title}")
                gr.Markdown(gui_description)

        with gr.Row():
            with gr.Column():
                input_text_gr = gr.Textbox(
                    label="Input Text",
                    info="One or two sentences at a time is better. Up to 200 text characters.",
                    value="Today I want to tell you three stories from my life. That's it. No big deal. Just three stories.",
                )

                ref_audio_gr = gr.Dropdown(
                    label='reference audio',
                    choices=audios,
                    value = 0
                )

                checkpoint_gr = gr.Dropdown(
                    label='checkpoint',
                    choices=tts_checkpoint_path,
                    value = 0
                )

                step_gr = gr.Slider(
                    label='Step',
                    minimum=1,
                    maximum=40,
                    value=8,
                    step=1
                )

                tts_button = gr.Button("Send", elem_id="send-btn", visible=True)

            with gr.Column():
                mel_gr = gr.Plot(label="Mel Visual")
                audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)

        tts_button.click(inference, [input_text_gr, ref_audio_gr, checkpoint_gr, step_gr], outputs=[audio_gr, mel_gr])

    demo.queue()
    demo.launch(debug=True, show_api=True)


if __name__ == '__main__':
    main()
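For context on the output handling in inference above: gr.Audio accepts a (sample_rate, int16 numpy array) tuple, which is why the float waveform is scaled by 32767. A standalone sketch of that conversion on a dummy signal (illustrative only, not part of the uploaded files):

import numpy as np

sample_rate = 44100
# one second of a 440 Hz sine in [-1, 1], standing in for the vocoder output
float_audio = np.sin(2 * np.pi * 440 * np.arange(sample_rate) / sample_rate)

# scale to the int16 range, matching the tuple format gr.Audio expects
audio_output = (sample_rate, (float_audio * 32767).astype(np.int16))
print(audio_output[1].dtype, audio_output[1].min(), audio_output[1].max())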
audios/audio0.flac
ADDED
Binary file (84 kB).

audios/audio1.flac
ADDED
Binary file (151 kB).

audios/audio2.flac
ADDED
Binary file (318 kB).

audios/audio3.flac
ADDED
Binary file (162 kB).

audios/audio4.flac
ADDED
Binary file (260 kB).

audios/audio5.flac
ADDED
Binary file (361 kB).

audios/audio6.flac
ADDED
Binary file (99.1 kB).

checkpoints/.keep
ADDED
File without changes
checkpoints/checkpoint_0.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7c473c5dc40bf87e8472f8688790cb20a2ec10494fa08cb710d657cf1f892d44
size 37552627
checkpoints/vocoder.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e180a1df6ca0a9e382e0915b0b1984aecfb63397c4ee21f12857447c2d76d29a
size 56666508
config.py
ADDED
@@ -0,0 +1,52 @@
from dataclasses import dataclass

@dataclass
class MelConfig:
    sample_rate: int = 44100
    n_fft: int = 2048
    win_length: int = 2048
    hop_length: int = 512
    f_min: float = 0.0
    f_max: float = None
    pad: int = 0
    n_mels: int = 128
    power: float = 1.0
    normalized: bool = False
    center: bool = False
    pad_mode: str = "reflect"
    mel_scale: str = "htk"

    def __post_init__(self):
        if self.pad == 0:
            self.pad = (self.n_fft - self.hop_length) // 2

@dataclass
class ModelConfig:
    hidden_channels: int = 192
    filter_channels: int = 512
    n_heads: int = 2
    n_enc_layers: int = 3
    n_dec_layers: int = 2
    kernel_size: int = 3
    p_dropout: float = 0.1
    gin_channels: int = 192

@dataclass
class TrainConfig:
    train_dataset_path: str = 'filelists/filelist.json'
    test_dataset_path: str = 'filelists/filelist.json'
    batch_size: int = 52
    learning_rate: float = 1e-4
    num_epochs: int = 10000
    model_save_path: str = './checkpoints'
    log_dir: str = './runs'
    log_interval: int = 128
    save_interval: int = 15
    warmup_steps: int = 200

@dataclass
class VocosConfig:
    input_channels: int = 128
    dim: int = 512
    intermediate_dim: int = 1536
    num_layers: int = 8
datas/__init__.py
ADDED
File without changes

datas/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (154 Bytes).

datas/__pycache__/dataset.cpython-311.pyc
ADDED
Binary file (4.57 kB).
datas/dataset.py
ADDED
@@ -0,0 +1,52 @@
import os

import json
import torch
from torch.utils.data import Dataset

from text import cleaned_text_to_sequence

def intersperse(lst, item):
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result

class StableDataset(Dataset):
    def __init__(self, filelist_path, hop_length):
        self.filelist_path = filelist_path
        self.hop_length = hop_length

        self._load_filelist(filelist_path)

    def _load_filelist(self, filelist_path):
        filelist, lengths = [], []
        with open(filelist_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = json.loads(line.strip())
                filelist.append((line['mel_path'], line['phone']))
                lengths.append(os.path.getsize(line['audio_path']) // (2 * self.hop_length))

        self.filelist = filelist
        self.lengths = lengths

    def __len__(self):
        return len(self.filelist)

    def __getitem__(self, idx):
        mel_path, phone = self.filelist[idx]
        mel = torch.load(mel_path, map_location='cpu')
        phone = torch.tensor(intersperse(cleaned_text_to_sequence(phone), 0), dtype=torch.long)
        return mel, phone

def collate_fn(batch):
    texts = [item[1] for item in batch]
    mels = [item[0] for item in batch]

    text_lengths = torch.tensor([text.size(-1) for text in texts], dtype=torch.long)
    mel_lengths = torch.tensor([mel.size(-1) for mel in mels], dtype=torch.long)

    # pad to the same length
    texts_padded = torch.nested.to_padded_tensor(torch.nested.nested_tensor(texts), padding=0)
    mels_padded = torch.nested.to_padded_tensor(torch.nested.nested_tensor(mels), padding=0)

    return texts_padded, text_lengths, mels_padded, mel_lengths
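intersperse above inserts a blank token id between and around every phoneme id, so a sequence of length n becomes length 2n + 1. A standalone sketch of that behaviour (toy ids only):

def intersperse(lst, item):
    # same logic as in datas/dataset.py: pre-fill with the separator, then place lst at odd indices
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result

print(intersperse([17, 4, 29], 0))  # [0, 17, 0, 4, 0, 29, 0]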
datas/sampler.py
ADDED
@@ -0,0 +1,121 @@
import torch

# reference: https://github.com/jaywalnut310/vits/blob/main/data_utils.py
class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
    """
    Maintain similar input lengths in a batch.
    Length groups are specified by boundaries.
    Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <= b2} or {x | b2 < length(x) <= b3}.

    It removes samples which are not included in the boundaries.
    Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
    """

    def __init__(
        self,
        dataset,
        batch_size,
        boundaries,
        num_replicas=None,
        rank=None,
        shuffle=True,
    ):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
        self.lengths = dataset.lengths
        self.batch_size = batch_size
        self.boundaries = boundaries

        self.buckets, self.num_samples_per_bucket = self._create_buckets()
        self.total_size = sum(self.num_samples_per_bucket)
        self.num_samples = self.total_size // self.num_replicas

    def _create_buckets(self):
        buckets = [[] for _ in range(len(self.boundaries) - 1)]
        for i in range(len(self.lengths)):
            length = self.lengths[i]
            idx_bucket = self._bisect(length)
            if idx_bucket != -1:
                buckets[idx_bucket].append(i)

        for i in range(len(buckets) - 1, 0, -1):
            # for i in range(len(buckets) - 1, -1, -1):
            if len(buckets[i]) == 0:
                buckets.pop(i)
                self.boundaries.pop(i + 1)

        num_samples_per_bucket = []
        for i in range(len(buckets)):
            len_bucket = len(buckets[i])
            total_batch_size = self.num_replicas * self.batch_size
            rem = (
                total_batch_size - (len_bucket % total_batch_size)
            ) % total_batch_size
            num_samples_per_bucket.append(len_bucket + rem)
        return buckets, num_samples_per_bucket

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)

        indices = []
        if self.shuffle:
            for bucket in self.buckets:
                indices.append(torch.randperm(len(bucket), generator=g).tolist())
        else:
            for bucket in self.buckets:
                indices.append(list(range(len(bucket))))

        batches = []
        for i in range(len(self.buckets)):
            bucket = self.buckets[i]
            len_bucket = len(bucket)
            ids_bucket = indices[i]
            num_samples_bucket = self.num_samples_per_bucket[i]

            # add extra samples to make it evenly divisible
            rem = num_samples_bucket - len_bucket
            ids_bucket = (
                ids_bucket
                + ids_bucket * (rem // len_bucket)
                + ids_bucket[: (rem % len_bucket)]
            )

            # subsample
            ids_bucket = ids_bucket[self.rank :: self.num_replicas]

            # batching
            for j in range(len(ids_bucket) // self.batch_size):
                batch = [
                    bucket[idx]
                    for idx in ids_bucket[
                        j * self.batch_size : (j + 1) * self.batch_size
                    ]
                ]
                batches.append(batch)

        if self.shuffle:
            batch_ids = torch.randperm(len(batches), generator=g).tolist()
            batches = [batches[i] for i in batch_ids]
        self.batches = batches

        assert len(self.batches) * self.batch_size == self.num_samples
        return iter(self.batches)

    def _bisect(self, x, lo=0, hi=None):
        if hi is None:
            hi = len(self.boundaries) - 1

        if hi > lo:
            mid = (hi + lo) // 2
            if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
                return mid
            elif x <= self.boundaries[mid]:
                return self._bisect(x, lo, mid)
            else:
                return self._bisect(x, mid + 1, hi)
        else:
            return -1

    def __len__(self):
        return self.num_samples // self.batch_size
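The bucket sampler above yields whole batches of indices, so it plugs into a DataLoader through batch_sampler rather than sampler. A minimal single-process sketch (the dummy dataset, lengths and boundaries are illustrative; num_replicas=1 and rank=0 avoid requiring torch.distributed to be initialised):

import torch
from torch.utils.data import Dataset, DataLoader

from datas.sampler import DistributedBucketSampler  # the class defined in the diff above

class DummyLengthDataset(Dataset):
    # only what DistributedBucketSampler needs: __len__, __getitem__ and a .lengths list
    def __init__(self, lengths):
        self.lengths = lengths
    def __len__(self):
        return len(self.lengths)
    def __getitem__(self, idx):
        return torch.zeros(self.lengths[idx])

dataset = DummyLengthDataset(lengths=[30, 45, 80, 95, 120, 150, 60, 70])
sampler = DistributedBucketSampler(dataset, batch_size=2, boundaries=[20, 100, 200],
                                   num_replicas=1, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_sampler=sampler, collate_fn=lambda b: b)

sampler.set_epoch(0)  # reseeds the per-epoch shuffle
for batch in loader:
    print([item.shape[0] for item in batch])  # lengths inside a batch fall in the same bucket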
models/__init__.py
ADDED
File without changes

models/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (135 Bytes).

models/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (155 Bytes).

models/__pycache__/dit.cpython-311.pyc
ADDED
Binary file (13.7 kB).

models/__pycache__/duration_predictor.cpython-311.pyc
ADDED
Binary file (3.2 kB).

models/__pycache__/estimator.cpython-311.pyc
ADDED
Binary file (12.7 kB).

models/__pycache__/flow_matching.cpython-311.pyc
ADDED
Binary file (5.92 kB).

models/__pycache__/model.cpython-310.pyc
ADDED
Binary file (6.49 kB).

models/__pycache__/model.cpython-311.pyc
ADDED
Binary file (11.6 kB).

models/__pycache__/reference_encoder.cpython-311.pyc
ADDED
Binary file (5.32 kB).

models/__pycache__/text_encoder.cpython-311.pyc
ADDED
Binary file (4.22 kB).
models/dit.py
ADDED
@@ -0,0 +1,205 @@
# References:
# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/transformer.py
# https://github.com/jaywalnut310/vits/blob/main/attentions.py
# https://github.com/pytorch-labs/gpt-fast/blob/main/model.py

import torch
import torch.nn as nn
import torch.nn.functional as F

class FFN(nn.Module):
    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., gin_channels=0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.drop = nn.Dropout(p_dropout)
        self.act1 = nn.GELU(approximate="tanh")

    def forward(self, x, x_mask):
        x = self.conv_1(x * x_mask)
        x = self.act1(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        return x * x_mask

class MultiHeadAttention(nn.Module):
    def __init__(self, channels, out_channels, n_heads, p_dropout=0.):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout

        self.k_channels = channels // n_heads
        self.conv_q = torch.nn.Conv1d(channels, channels, 1)
        self.conv_k = torch.nn.Conv1d(channels, channels, 1)
        self.conv_v = torch.nn.Conv1d(channels, channels, 1)

        # from https://nn.labml.ai/transformers/rope/index.html
        self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
        self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)

        self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
        self.drop = torch.nn.Dropout(p_dropout)

        torch.nn.init.xavier_uniform_(self.conv_q.weight)
        torch.nn.init.xavier_uniform_(self.conv_k.weight)
        torch.nn.init.xavier_uniform_(self.conv_v.weight)

    def forward(self, x, attn_mask=None):
        q = self.conv_q(x)
        k = self.conv_k(x)
        v = self.conv_v(x)

        x = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        query = self.query_rotary_pe(query) # [b, n_head, t, c // n_head]
        key = self.key_rotary_pe(key)

        output = F.scaled_dot_product_attention(query, key, value, attn_mask=mask, dropout_p=self.p_dropout if self.training else 0)
        output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
        return output

# modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/modules.py#L390
class DiTConVBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """
    def __init__(self, hidden_channels, filter_channels, num_heads, kernel_size=3, p_dropout=0.1, gin_channels=0):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_channels, elementwise_affine=False, eps=1e-6)
        self.attn = MultiHeadAttention(hidden_channels, hidden_channels, num_heads, p_dropout)
        self.norm2 = nn.LayerNorm(hidden_channels, elementwise_affine=False, eps=1e-6)
        self.mlp = FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)
        self.adaLN_modulation = nn.Sequential(
            nn.Linear(gin_channels, hidden_channels) if gin_channels != hidden_channels else nn.Identity(),
            nn.SiLU(),
            nn.Linear(hidden_channels, 6 * hidden_channels, bias=True)
        )

    def forward(self, x, c, x_mask):
        """
        Args:
            x : [batch_size, channel, time]
            c : [batch_size, channel]
            x_mask : [batch_size, 1, time]
        return the same shape as x
        """
        x = x * x_mask
        attn_mask = x_mask.unsqueeze(1) * x_mask.unsqueeze(-1) # shape: [batch_size, 1, time, time]
        # attn_mask = attn_mask.to(torch.bool)

        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).unsqueeze(2).chunk(6, dim=1) # shape: [batch_size, channel, 1]
        x = x + gate_msa * self.attn(self.modulate(self.norm1(x.transpose(1,2)).transpose(1,2), shift_msa, scale_msa), attn_mask) * x_mask
        x = x + gate_mlp * self.mlp(self.modulate(self.norm2(x.transpose(1,2)).transpose(1,2), shift_mlp, scale_mlp), x_mask)

        # no condition version
        # x = x + self.attn(self.norm1(x.transpose(1,2)).transpose(1,2), attn_mask)
        # x = x + self.mlp(self.norm1(x.transpose(1,2)).transpose(1,2), x_mask)
        return x

    @staticmethod
    def modulate(x, shift, scale):
        return x * (1 + scale) + shift

class RotaryPositionalEmbeddings(nn.Module):
    """
    ## RoPE module

    Rotary encoding transforms pairs of features by rotating in the 2D plane.
    That is, it organizes the $d$ features as $\frac{d}{2}$ pairs.
    Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it
    by an angle depending on the position of the token.
    """

    def __init__(self, d: int, base: int = 10_000):
        r"""
        * `d` is the number of features $d$
        * `base` is the constant used for calculating $\Theta$
        """
        super().__init__()

        self.base = base
        self.d = int(d)
        self.cos_cached = None
        self.sin_cached = None

    def _build_cache(self, x: torch.Tensor):
        r"""
        Cache $\cos$ and $\sin$ values
        """
        # Return if cache is already built
        if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
            return

        # Get sequence length
        seq_len = x.shape[0]

        # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)

        # Create position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)

        # Calculate the product of position index and $\theta_i$
        idx_theta = torch.einsum("n,d->nd", seq_idx, theta)

        # Concatenate so that for row $m$ we have
        # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)

        # Cache them
        self.cos_cached = idx_theta2.cos()[:, None, None, :]
        self.sin_cached = idx_theta2.sin()[:, None, None, :]

    def _neg_half(self, x: torch.Tensor):
        # $\frac{d}{2}$
        d_2 = self.d // 2

        # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
        return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)

    def forward(self, x: torch.Tensor):
        """
        * `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]`
        """
        # Cache $\cos$ and $\sin$ values
        x = x.permute(2, 0, 1, 3) # b h t d -> t b h d

        self._build_cache(x)

        # Split the features, we can choose to apply rotary embeddings only to a partial set of features.
        x_rope, x_pass = x[..., : self.d], x[..., self.d :]

        # Calculate
        # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
        neg_half_x = self._neg_half(x_rope)

        x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]])

        return torch.cat((x_rope, x_pass), dim=-1).permute(1, 2, 0, 3) # t b h d -> b h t d

class Transpose(nn.Identity):
    """(N, T, D) -> (N, D, T)"""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input.transpose(1, 2)
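The adaLN-Zero conditioning in DiTConVBlock scales and shifts the normalised activations with parameters predicted from the conditioning vector, and the last linear layer of adaLN_modulation is zero-initialised by the TextEncoder and Decoder initialisers shown elsewhere in this diff, so each block starts close to an identity mapping. A standalone sketch of the modulate arithmetic on toy tensors (shapes are illustrative):

import torch

def modulate(x, shift, scale):
    # same formula as DiTConVBlock.modulate
    return x * (1 + scale) + shift

x = torch.randn(2, 192, 50)     # [batch, channel, time]
shift = torch.zeros(2, 192, 1)  # zero-initialised adaLN gives shift = scale = 0
scale = torch.zeros(2, 192, 1)
print(torch.allclose(modulate(x, shift, scale), x))   # True: modulation starts as identity

scale = torch.full((2, 192, 1), 0.5)                   # non-zero scale rescales every channel
print(modulate(x, shift, scale).std() > x.std())       # larger spread after modulation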
models/duration_predictor.py
ADDED
@@ -0,0 +1,40 @@
import torch
import torch.nn as nn

# modified from https://github.com/jaywalnut310/vits/blob/main/models.py#L98
class DurationPredictor(nn.Module):
    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2)
        self.norm_1 = nn.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2)
        self.norm_2 = nn.LayerNorm(filter_channels)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

        self.cond = nn.Conv1d(gin_channels, in_channels, 1)

    def forward(self, x, x_mask, g):
        x = x.detach()
        x = x + self.cond(g.unsqueeze(2).detach())
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x.transpose(1,2)).transpose(1,2)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x.transpose(1,2)).transpose(1,2)
        x = self.drop(x)
        x = self.proj(x * x_mask)
        return x * x_mask

def duration_loss(logw, logw_, lengths):
    loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths)
    return loss
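duration_loss above is a length-normalised squared error between predicted and MAS-derived log durations. A toy sketch of the computation (dummy values, not tied to any checkpoint):

import torch

def duration_loss(logw, logw_, lengths):
    # identical to the function in models/duration_predictor.py
    return torch.sum((logw - logw_) ** 2) / torch.sum(lengths)

logw  = torch.log(torch.tensor([[2.0, 3.0, 1.0]]))  # predicted log-durations for 3 tokens
logw_ = torch.log(torch.tensor([[2.0, 4.0, 1.0]]))  # target log-durations from MAS
lengths = torch.tensor([3])                          # number of tokens in the sequence
print(duration_loss(logw, logw_, lengths))           # sum of squared differences / total length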
models/estimator.py
ADDED
@@ -0,0 +1,161 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from models.dit import DiTConVBlock

class DitWrapper(nn.Module):
    """ add FiLM layer to condition time embedding to DiT """
    def __init__(self, hidden_channels, filter_channels, num_heads, kernel_size=3, p_dropout=0.1, gin_channels=0, time_channels=0):
        super().__init__()
        self.time_fusion = FiLMLayer(hidden_channels, time_channels)
        self.conv1 = ConvNeXtBlock(hidden_channels, filter_channels, gin_channels)
        self.conv2 = ConvNeXtBlock(hidden_channels, filter_channels, gin_channels)
        self.conv3 = ConvNeXtBlock(hidden_channels, filter_channels, gin_channels)
        self.block = DiTConVBlock(hidden_channels, hidden_channels, num_heads, kernel_size, p_dropout, gin_channels)

    def forward(self, x, c, t, x_mask):
        x = self.time_fusion(x, t) * x_mask
        x = self.conv1(x, c, x_mask)
        x = self.conv2(x, c, x_mask)
        x = self.conv3(x, c, x_mask)
        x = self.block(x, c, x_mask)
        return x

class FiLMLayer(nn.Module):
    """
    Feature-wise Linear Modulation (FiLM) layer
    Reference: https://arxiv.org/abs/1709.07871
    """
    def __init__(self, in_channels, cond_channels):

        super(FiLMLayer, self).__init__()
        self.in_channels = in_channels
        self.film = nn.Conv1d(cond_channels, in_channels * 2, 1)

    def forward(self, x, c):
        gamma, beta = torch.chunk(self.film(c.unsqueeze(2)), chunks=2, dim=1)
        return gamma * x + beta

class ConvNeXtBlock(nn.Module):
    def __init__(self, in_channels, filter_channels, gin_channels):
        super().__init__()
        self.dwconv = nn.Conv1d(in_channels, in_channels, kernel_size=7, padding=3, groups=in_channels)
        self.norm = StyleAdaptiveLayerNorm(in_channels, gin_channels)
        self.pwconv = nn.Sequential(nn.Linear(in_channels, filter_channels),
                                    nn.GELU(),
                                    nn.Linear(filter_channels, in_channels))

    def forward(self, x, c, x_mask) -> torch.Tensor:
        residual = x
        x = self.dwconv(x) * x_mask
        x = self.norm(x.transpose(1, 2), c)
        x = self.pwconv(x).transpose(1, 2)
        x = residual + x
        return x * x_mask

class StyleAdaptiveLayerNorm(nn.Module):
    def __init__(self, in_channels, cond_channels):
        """
        Style Adaptive Layer Normalization (SALN) module.

        Parameters:
        in_channels: The number of channels in the input feature maps.
        cond_channels: The number of channels in the conditioning input.
        """
        super(StyleAdaptiveLayerNorm, self).__init__()
        self.in_channels = in_channels

        self.saln = nn.Linear(cond_channels, in_channels * 2, 1)
        self.norm = nn.LayerNorm(in_channels, elementwise_affine=False)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.constant_(self.saln.bias.data[:self.in_channels], 1)
        nn.init.constant_(self.saln.bias.data[self.in_channels:], 0)

    def forward(self, x, c):
        gamma, beta = torch.chunk(self.saln(c.unsqueeze(1)), chunks=2, dim=-1)
        return gamma * self.norm(x) + beta


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"

    def forward(self, x, scale=1000):
        if x.ndim < 1:
            x = x.unsqueeze(0)
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=x.device).float() * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

class TimestepEmbedding(nn.Module):
    def __init__(self, in_channels, out_channels, filter_channels):
        super().__init__()

        self.layer = nn.Sequential(
            nn.Linear(in_channels, filter_channels),
            nn.SiLU(inplace=True),
            nn.Linear(filter_channels, out_channels)
        )

    def forward(self, x):
        return self.layer(x)

# reference: https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/decoder.py
class Decoder(nn.Module):
    def __init__(self, hidden_channels, out_channels, filter_channels, dropout=0.05, n_layers=1, n_heads=4, kernel_size=3, gin_channels=0):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels

        self.time_embeddings = SinusoidalPosEmb(hidden_channels)
        self.time_mlp = TimestepEmbedding(hidden_channels, hidden_channels, filter_channels)

        self.blocks = nn.ModuleList([DitWrapper(hidden_channels, filter_channels, n_heads, kernel_size, dropout, gin_channels, hidden_channels) for _ in range(n_layers)])
        self.final_proj = nn.Conv1d(hidden_channels, out_channels, 1)

        self.initialize_weights()

    def initialize_weights(self):
        for block in self.blocks:
            nn.init.constant_(block.block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.block.adaLN_modulation[-1].bias, 0)

    def forward(self, x, mask, mu, t, c):
        """Forward pass of the UNet1DConditional model.

        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (_type_): shape (batch_size, 1, time)
            t (_type_): shape (batch_size)
            c (_type_): shape (batch_size, gin_channels)

        Raises:
            ValueError: _description_
            ValueError: _description_

        Returns:
            _type_: _description_
        """

        t = self.time_mlp(self.time_embeddings(t))
        x = torch.cat((x, mu), dim=1)

        for block in self.blocks:
            x = block(x, c, t, mask)

        output = self.final_proj(x * mask)

        return output * mask
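FiLMLayer above produces a per-channel scale (gamma) and offset (beta) from the timestep embedding and applies them across the whole time axis. A minimal shape-level sketch with toy sizes (channel counts are illustrative):

import torch
import torch.nn as nn

cond_channels, in_channels, time = 192, 256, 100
film = nn.Conv1d(cond_channels, in_channels * 2, 1)  # same projection as FiLMLayer.film

x = torch.randn(4, in_channels, time)                # features: [batch, channels, time]
t_emb = torch.randn(4, cond_channels)                # timestep embedding: [batch, cond_channels]

gamma, beta = torch.chunk(film(t_emb.unsqueeze(2)), chunks=2, dim=1)  # each [batch, in_channels, 1]
out = gamma * x + beta                               # broadcast over the time axis
print(out.shape)                                     # torch.Size([4, 256, 100])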
models/flow_matching.py
ADDED
@@ -0,0 +1,108 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from models.estimator import Decoder

# copied from https://github.com/jaywalnut310/vits/blob/main/commons.py#L121
def sequence_mask(length: torch.Tensor, max_length: int = None) -> torch.Tensor:
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)

# modified from https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/flow_matching.py
class CFMDecoder(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, gin_channels):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.gin_channels = gin_channels
        self.sigma_min = 1e-4

        self.estimator = Decoder(hidden_channels, out_channels, filter_channels, p_dropout, n_layers, n_heads, kernel_size, gin_channels)

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, c=None):
        """Forward diffusion

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            c (torch.Tensor, optional): shape: (batch_size, gin_channels)

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """
        z = torch.randn_like(mu) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, c=c)

    def solve_euler(self, x, t_span, mu, mask, c):
        """
        Fixed euler solver for ODEs.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            c (torch.Tensor, optional): speaker condition.
                shape: (batch_size, gin_channels)
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]

        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
        # Or in future might add like a return_all_steps flag
        sol = []

        for step in range(1, len(t_span)):
            dphi_dt = self.estimator(x, mask, mu, t, c)

            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t

        return sol[-1]

    def compute_loss(self, x1, mask, mu, c):
        """Computes diffusion loss

        Args:
            x1 (torch.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): target mask
                shape: (batch_size, 1, mel_timesteps)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            c (torch.Tensor, optional): speaker condition.

        Returns:
            loss: conditional flow matching loss
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b, _, t = mu.shape

        # random timestep
        t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
        # sample noise p(x_0)
        z = torch.randn_like(x1)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), c), u, reduction="sum") / (
            torch.sum(mask) * u.shape[1]
        )
        return loss, y
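compute_loss above trains the estimator to predict the velocity of the conditional optimal-transport path y = (1 - (1 - sigma_min) t) z + t x1, whose time-derivative is the constant target u = x1 - (1 - sigma_min) z. A small standalone check that a finite difference along t recovers that target (toy tensors only, no model involved):

import torch

sigma_min = 1e-4
x1 = torch.randn(2, 128, 50)   # "data" sample, mel-spectrogram shaped
z = torch.randn_like(x1)       # noise sample
t = torch.rand(2, 1, 1)

x_t = (1 - (1 - sigma_min) * t) * z + t * x1   # point on the conditional OT path
u = x1 - (1 - sigma_min) * z                   # its time-derivative, the regression target

# finite-difference check: (x_{t+h} - x_t) / h should match u since the path is linear in t
h = 1e-3
x_th = (1 - (1 - sigma_min) * (t + h)) * z + (t + h) * x1
print(torch.allclose((x_th - x_t) / h, u, atol=1e-3))  # True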
models/model.py
ADDED
@@ -0,0 +1,194 @@
import math
import torch
import torch.nn as nn

import monotonic_align
from models.text_encoder import TextEncoder
from models.flow_matching import CFMDecoder
from models.reference_encoder import MelStyleEncoder
from models.duration_predictor import DurationPredictor, duration_loss

def sequence_mask(length: torch.Tensor, max_length: int = None) -> torch.Tensor:
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)

def convert_pad_shape(pad_shape):
    inverted_shape = pad_shape[::-1]
    pad_shape = [item for sublist in inverted_shape for item in sublist]
    return pad_shape

def generate_path(duration, mask):
    device = duration.device

    b, t_x, t_y = mask.shape
    cum_duration = torch.cumsum(duration, 1)
    path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path * mask
    return path

# modified from https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/matcha_tts.py
class StableTTS(nn.Module):
    def __init__(self, n_vocab, mel_channels, hidden_channels, filter_channels, n_heads, n_enc_layers, n_dec_layers, kernel_size, p_dropout, gin_channels):
        super().__init__()

        self.n_vocab = n_vocab
        self.mel_channels = mel_channels

        self.encoder = TextEncoder(n_vocab, mel_channels, hidden_channels, filter_channels, n_heads, n_enc_layers, kernel_size, p_dropout, gin_channels)
        self.ref_encoder = MelStyleEncoder(mel_channels, style_vector_dim=gin_channels, style_kernel_size=3)
        self.dp = DurationPredictor(hidden_channels, filter_channels, kernel_size, p_dropout, gin_channels)
        self.decoder = CFMDecoder(mel_channels + mel_channels, mel_channels, filter_channels, n_heads, n_dec_layers, kernel_size, p_dropout, gin_channels)

    @torch.inference_mode()
    def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, y=None, length_scale=1.0):
        """
        Generates mel-spectrogram from text. Returns:
        1. encoder outputs
        2. decoder outputs
        3. generated alignment

        Args:
            x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
                shape: (batch_size, max_text_length)
            x_lengths (torch.Tensor): lengths of texts in batch.
                shape: (batch_size,)
            n_timesteps (int): number of steps to use for reverse diffusion in decoder.
            temperature (float, optional): controls variance of terminal distribution.
            y (torch.Tensor): mel spectrogram of reference audio
                shape: (batch_size, mel_channels, time)
            length_scale (float, optional): controls speech pace.
                Increase value to slow down generated speech and vice versa.

        Returns:
            dict: {
                "encoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
                # Average mel spectrogram generated by the encoder
                "decoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
                # Refined mel spectrogram improved by the CFM
                "attn": torch.Tensor, shape: (batch_size, max_text_length, max_mel_length),
                # Alignment map between text and mel spectrogram
        """

        # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
        c = self.ref_encoder(y, None)
        x, mu_x, x_mask = self.encoder(x, c, x_lengths)
        logw = self.dp(x, x_mask, c)

        w = torch.exp(logw) * x_mask
        w_ceil = torch.ceil(w) * length_scale
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_max_length = y_lengths.max()

        # Using obtained durations `w` construct alignment map `attn`
        y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask.dtype)
        attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
        attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)

        # Align encoded text and get mu_y
        mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
        mu_y = mu_y.transpose(1, 2)
        encoder_outputs = mu_y[:, :, :y_max_length]

        # Generate sample tracing the probability flow
        decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, c)
        decoder_outputs = decoder_outputs[:, :, :y_max_length]

        return {
            "encoder_outputs": encoder_outputs,
            "decoder_outputs": decoder_outputs,
            "attn": attn[:, :, :y_max_length],
        }

    def forward(self, x, x_lengths, y, y_lengths):
        """
        Computes 3 losses:
            1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS).
            2. prior loss: loss between mel-spectrogram and encoder outputs.
            3. flow matching loss: loss between mel-spectrogram and decoder outputs.

        Args:
            x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
                shape: (batch_size, max_text_length)
            x_lengths (torch.Tensor): lengths of texts in batch.
                shape: (batch_size,)
            y (torch.Tensor): batch of corresponding mel-spectrograms.
                shape: (batch_size, n_feats, max_mel_length)
            y_lengths (torch.Tensor): lengths of mel-spectrograms in batch.
                shape: (batch_size,)
        """
        # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
        y_mask = sequence_mask(y_lengths, y.size(2)).unsqueeze(1).to(y.dtype)
        c = self.ref_encoder(y, y_mask)

        x, mu_x, x_mask = self.encoder(x, c, x_lengths)
        logw = self.dp(x, x_mask, c)

        attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)

        # Use MAS to find most likely alignment `attn` between text and mel-spectrogram

        # I'm not sure why the MAS code in Matcha TTS and Grad TTS could not align in StableTTS
        # so I use the code from https://github.com/p0p4k/pflowtts_pytorch/blob/master/pflow/models/pflow_tts.py and it works
        # Welcome everyone to solve this problem QAQ

        with torch.no_grad():
            # const = -0.5 * math.log(2 * math.pi) * self.n_feats
            # const = -0.5 * math.log(2 * math.pi) * self.mel_channels
            # factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
            # y_square = torch.matmul(factor.transpose(1, 2), y**2)
            # y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
            # mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
            # log_prior = y_square - y_mu_double + mu_square + const

            s_p_sq_r = torch.ones_like(mu_x)  # [b, d, t]
            # s_p_sq_r = torch.exp(-2 * logx)
            neg_cent1 = torch.sum(
                -0.5 * math.log(2 * math.pi) - torch.zeros_like(mu_x), [1], keepdim=True
            )
            # neg_cent1 = torch.sum(
            #     -0.5 * math.log(2 * math.pi) - logx, [1], keepdim=True
            # ) # [b, 1, t_s]
            neg_cent2 = torch.einsum("bdt, bds -> bts", -0.5 * (y**2), s_p_sq_r)
            neg_cent3 = torch.einsum("bdt, bds -> bts", y, (mu_x * s_p_sq_r))
            neg_cent4 = torch.sum(
                -0.5 * (mu_x**2) * s_p_sq_r, [1], keepdim=True
            )
            neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4

            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)

            attn = (
                monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()
            )

            # attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1))
            # attn = attn.detach()

        # Compute loss between predicted log-scaled durations and those obtained from MAS
        # referred to as prior loss in the paper
        logw_ = torch.log(1e-8 + attn.sum(2)) * x_mask
        # logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask
        dur_loss = duration_loss(logw, logw_, x_lengths)

        # Align encoded text with mel-spectrogram and get mu_y segment
        attn = attn.squeeze(1).transpose(1,2)
        mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
        mu_y = mu_y.transpose(1, 2)

        # Compute loss of the decoder
        diff_loss, _ = self.decoder.compute_loss(y, y_mask, mu_y, c)
        # diff_loss = torch.tensor([0], device=mu_y.device)

        prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
        prior_loss = prior_loss / (torch.sum(y_mask) * self.mel_channels)

        return dur_loss, diff_loss, prior_loss, attn
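synthesise above expands each text token by its predicted (ceiled) duration via generate_path, producing a hard monotonic alignment between text positions and mel frames. The repo's implementation uses a cumulative-sum trick; the following is an equivalent illustration of the same expansion using repeat_interleave on toy durations (a sketch only, not the repo's code):

import torch

durations = torch.tensor([2, 1, 3])             # frames assigned to each of 3 text tokens
t_text, t_mel = len(durations), int(durations.sum())

# which text token each mel frame attends to: [0, 0, 1, 2, 2, 2]
frame_to_token = torch.repeat_interleave(torch.arange(t_text), durations)

attn = torch.zeros(t_text, t_mel)
attn[frame_to_token, torch.arange(t_mel)] = 1.0  # one-hot monotonic alignment map
print(attn)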
models/reference_encoder.py
ADDED
@@ -0,0 +1,93 @@
import torch
import torch.nn as nn

class Conv1dGLU(nn.Module):
    """
    Conv1d + GLU (Gated Linear Unit) with residual connection.
    For GLU refer to https://arxiv.org/abs/1612.08083 paper.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dropout):
        super(Conv1dGLU, self).__init__()
        self.out_channels = out_channels
        self.conv1 = nn.Conv1d(in_channels, 2 * out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x1, x2 = torch.split(x, self.out_channels, dim=1)
        x = x1 * torch.sigmoid(x2)
        x = residual + self.dropout(x)
        return x

# modified from https://github.com/RVC-Boss/GPT-SoVITS/blob/main/GPT_SoVITS/module/modules.py#L766
class MelStyleEncoder(nn.Module):
    """MelStyleEncoder"""

    def __init__(
        self,
        n_mel_channels=80,
        style_hidden=128,
        style_vector_dim=256,
        style_kernel_size=5,
        style_head=2,
        dropout=0.1,
    ):
        super(MelStyleEncoder, self).__init__()
        self.in_dim = n_mel_channels
        self.hidden_dim = style_hidden
        self.out_dim = style_vector_dim
        self.kernel_size = style_kernel_size
        self.n_head = style_head
        self.dropout = dropout

        self.spectral = nn.Sequential(
            nn.Linear(self.in_dim, self.hidden_dim),
            nn.Mish(inplace=True),
            nn.Dropout(self.dropout),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.Mish(inplace=True),
            nn.Dropout(self.dropout),
        )

        self.temporal = nn.Sequential(
            Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
            Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
        )

        self.slf_attn = nn.MultiheadAttention(
            self.hidden_dim,
            self.n_head,
            self.dropout,
            batch_first=True
        )

        self.fc = nn.Linear(self.hidden_dim, self.out_dim)

    def temporal_avg_pool(self, x, mask=None):
        if mask is None:
            return torch.mean(x, dim=1)
        else:
            len_ = (~mask).sum(dim=1).unsqueeze(1).type_as(x)
            return torch.sum(x * ~mask.unsqueeze(-1), dim=1) / len_

    def forward(self, x, x_mask=None):
        x = x.transpose(1, 2)

        # spectral
        x = self.spectral(x)
        # temporal
        x = x.transpose(1, 2)
        x = self.temporal(x)
        x = x.transpose(1, 2)
        # self-attention
        if x_mask is not None:
            x_mask = ~x_mask.squeeze(1).to(torch.bool)
        x, _ = self.slf_attn(x, x, x, key_padding_mask=x_mask)
        # fc
        x = self.fc(x)
        # temporal average pooling
        w = self.temporal_avg_pool(x, mask=x_mask)

        return w
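A shape-only usage sketch for MelStyleEncoder (random inputs, default hyperparameters): a reference mel-spectrogram goes in, and one style vector per utterance comes out via masked average pooling over time.

import torch

enc = MelStyleEncoder()               # defaults: 80 mel bins -> 256-dim style vector
mel = torch.randn(2, 80, 120)         # [batch, n_mel_channels, frames]
mel_mask = torch.ones(2, 1, 120)      # 1 = valid frame (same convention as sequence_mask)
style = enc(mel, mel_mask)            # -> [2, 256]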
models/text_encoder.py
ADDED
@@ -0,0 +1,49 @@
import torch
import torch.nn as nn

from models.dit import DiTConVBlock

def sequence_mask(length: torch.Tensor, max_length: int = None) -> torch.Tensor:
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)

# modified from https://github.com/jaywalnut310/vits/blob/main/models.py
class TextEncoder(nn.Module):
    def __init__(self, n_vocab, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, gin_channels):
        super().__init__()
        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        self.scale = self.hidden_channels ** 0.5

        self.emb = nn.Embedding(n_vocab, hidden_channels)
        nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)

        self.encoder = nn.ModuleList([DiTConVBlock(hidden_channels, filter_channels, n_heads, kernel_size, p_dropout, gin_channels) for _ in range(n_layers)])
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)

        self.initialize_weights()

    def initialize_weights(self):
        for block in self.encoder:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

    def forward(self, x: torch.Tensor, c: torch.Tensor, x_lengths: torch.Tensor):
        x = self.emb(x) * self.scale  # [b, t, h]
        x = x.transpose(1, -1)  # [b, h, t]
        x_mask = sequence_mask(x_lengths, x.size(2)).unsqueeze(1).to(x.dtype)

        for layer in self.encoder:
            x = layer(x, c, x_mask)
        mu_x = self.proj(x) * x_mask

        return x, mu_x, x_mask
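A shape-level sketch of how TextEncoder is typically driven; the hyperparameter values below are illustrative (the real ones come from config.py), and `c` is assumed to be the style vector from MelStyleEncoder, so gin_channels must match its dimension.

import torch

enc = TextEncoder(n_vocab=100, out_channels=80, hidden_channels=192,
                  filter_channels=768, n_heads=2, n_layers=4,
                  kernel_size=3, p_dropout=0.1, gin_channels=256)
tokens = torch.randint(0, 100, (2, 21))        # [batch, t_text] symbol IDs
x_lengths = torch.tensor([21, 15])             # valid length of each sequence
c = torch.randn(2, 256)                        # conditioning / style vector
x, mu_x, x_mask = enc(tokens, c, x_lengths)    # [2, 192, 21], [2, 80, 21], [2, 1, 21]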
monotonic_align/__init__.py
ADDED
@@ -0,0 +1,16 @@
from numpy import zeros, int32, float32
from torch import from_numpy

from .core import maximum_path_jit


def maximum_path(neg_cent, mask):
    device = neg_cent.device
    dtype = neg_cent.dtype
    neg_cent = neg_cent.data.cpu().numpy().astype(float32)
    path = zeros(neg_cent.shape, dtype=int32)

    t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32)
    t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32)
    maximum_path_jit(path, neg_cent, t_t_max, t_s_max)
    return from_numpy(path).to(device=device, dtype=dtype)
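A toy call showing the wrapper's contract: given a score matrix and a validity mask of shape [batch, t_mel, t_text], it returns a 0/1 alignment of the same shape in which every mel frame is assigned to exactly one text token and the token index never decreases over time.

import torch

scores = torch.randn(1, 6, 3)         # [batch, t_mel, t_text] log-likelihood-style scores
mask = torch.ones(1, 6, 3)            # all positions valid
path = maximum_path(scores, mask)     # 0/1 tensor, same shape, monotonic alignment
print(path[0])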
monotonic_align/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (732 Bytes).
monotonic_align/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (1.64 kB).
monotonic_align/__pycache__/core.cpython-310.pyc
ADDED
Binary file (985 Bytes).
monotonic_align/__pycache__/core.cpython-311.pyc
ADDED
Binary file (2 kB).
monotonic_align/core.py
ADDED
@@ -0,0 +1,46 @@
import numba


@numba.jit(
    numba.void(
        numba.int32[:, :, ::1],
        numba.float32[:, :, ::1],
        numba.int32[::1],
        numba.int32[::1],
    ),
    nopython=True,
    nogil=True,
)
def maximum_path_jit(paths, values, t_ys, t_xs):
    b = paths.shape[0]
    max_neg_val = -1e9
    for i in range(int(b)):
        path = paths[i]
        value = values[i]
        t_y = t_ys[i]
        t_x = t_xs[i]

        v_prev = v_cur = 0.0
        index = t_x - 1

        for y in range(t_y):
            for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
                if x == y:
                    v_cur = max_neg_val
                else:
                    v_cur = value[y - 1, x]
                if x == 0:
                    if y == 0:
                        v_prev = 0.0
                    else:
                        v_prev = max_neg_val
                else:
                    v_prev = value[y - 1, x - 1]
                value[y, x] += max(v_prev, v_cur)

        for y in range(t_y - 1, -1, -1):
            path[y, index] = 1
            if index != 0 and (
                index == y or value[y - 1, index] < value[y - 1, index - 1]
            ):
                index = index - 1
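The kernel above is the usual monotonic-alignment dynamic program: the forward pass accumulates the best cumulative score under the constraint that each new frame either stays on the current token or advances by one, and the backward loop writes out the chosen path. A small worked example calling it directly (values are log-probabilities; the expected path is annotated):

import numpy as np

values = np.log(np.array([[[0.9, 0.1],
                           [0.6, 0.4],
                           [0.2, 0.8],
                           [0.1, 0.9]]], dtype=np.float32))   # [1, t_y=4 frames, t_x=2 tokens]
paths = np.zeros(values.shape, dtype=np.int32)
maximum_path_jit(paths, values,
                 np.array([4], dtype=np.int32),               # t_ys
                 np.array([2], dtype=np.int32))               # t_xs
print(paths[0])
# [[1 0]
#  [1 0]
#  [0 1]
#  [0 1]]   -> frames 0-1 on token 0, frames 2-3 on token 1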
requirements.txt
ADDED
@@ -0,0 +1,12 @@
torch
torchaudio
matplotlib
numpy
tensorboard
pypinyin
jieba
eng_to_ipa
unidecode
inflect
pyopenjtalk-prebuilt
numba
text/LICENSE
ADDED
@@ -0,0 +1,19 @@
Copyright (c) 2017 Keith Ito

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
text/__init__.py
ADDED
@@ -0,0 +1,71 @@
""" from https://github.com/keithito/tacotron """
from text import cleaners
from text.symbols import symbols


# Mappings from symbol to numeric ID and vice versa:
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}


def text_to_sequence(text, symbols, cleaner_names):
    '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
    Args:
        text: string to convert to a sequence
        cleaner_names: names of the cleaner functions to run the text through
    Returns:
        List of integers corresponding to the symbols in the text
    '''
    sequence = []
    symbol_to_id = {s: i for i, s in enumerate(symbols)}
    clean_text = _clean_text(text, cleaner_names)
    print(clean_text)
    print(f" length:{len(clean_text)}")
    for symbol in clean_text:
        if symbol not in symbol_to_id.keys():
            continue
        symbol_id = symbol_to_id[symbol]
        sequence += [symbol_id]
    print(f" length:{len(sequence)}")
    return sequence


def cleaned_text_to_sequence(cleaned_text):
    '''Converts a string of already-cleaned text to a sequence of IDs corresponding to the symbols in the text.
    Args:
        cleaned_text: cleaned string to convert to a sequence
    Returns:
        List of integers corresponding to the symbols in the text
    '''
    # symbol_to_id = {s: i for i, s in enumerate(symbols)}
    sequence = [_symbol_to_id[symbol] for symbol in cleaned_text if symbol in _symbol_to_id.keys()]
    return sequence

def cleaned_text_to_sequence_chinese(cleaned_text):
    '''Converts a string of space-separated cleaned symbols to a sequence of IDs corresponding to the symbols in the text.
    Args:
        cleaned_text: cleaned, space-separated string to convert to a sequence
    Returns:
        List of integers corresponding to the symbols in the text
    '''
    # symbol_to_id = {s: i for i, s in enumerate(symbols)}
    sequence = [_symbol_to_id[symbol] for symbol in cleaned_text.split(' ') if symbol in _symbol_to_id.keys()]
    return sequence


def sequence_to_text(sequence):
    '''Converts a sequence of IDs back to a string'''
    result = ''
    for symbol_id in sequence:
        s = _id_to_symbol[symbol_id]
        result += s
    return result


def _clean_text(text, cleaner_names):
    for name in cleaner_names:
        cleaner = getattr(cleaners, name)
        if not cleaner:
            raise Exception('Unknown cleaner: %s' % name)
        text = cleaner(text)
    return text
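For illustration, the ID-mapping logic of cleaned_text_to_sequence boils down to a dictionary lookup per symbol; a toy run-through with a hypothetical symbol inventory (the real one is text.symbols.symbols):

# Hypothetical symbol inventory, for illustration only.
demo_symbols = ['_', ' ', 'n', 'i', 'h', 'a', 'o']
demo_symbol_to_id = {s: i for i, s in enumerate(demo_symbols)}

cleaned = 'ni hao'
sequence = [demo_symbol_to_id[s] for s in cleaned if s in demo_symbol_to_id]
print(sequence)    # [2, 3, 1, 4, 5, 6]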
text/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (2.62 kB).
text/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (3.89 kB).
text/__pycache__/cleaners.cpython-310.pyc
ADDED
Binary file (2.54 kB).
text/__pycache__/cleaners.cpython-311.pyc
ADDED
Binary file (4.2 kB).
text/__pycache__/english.cpython-310.pyc
ADDED
Binary file (4.4 kB).