mrfakename
commited on
Commit
·
1646c30
verified
·
0
Parent(s):
Super-squash branch 'main' using huggingface_hub
Browse files- LICENSE +21 -0
- README.md +13 -0
- app.py +207 -0
- ckpts/README.md +10 -0
- data/Emilia_ZH_EN_pinyin/vocab.txt +2545 -0
- data/librispeech_pc_test_clean_cross_sentence.lst +0 -0
- model/__init__.py +7 -0
- model/backbones/README.md +20 -0
- model/backbones/dit.py +158 -0
- model/backbones/mmdit.py +136 -0
- model/backbones/unett.py +201 -0
- model/cfm.py +273 -0
- model/dataset.py +242 -0
- model/ecapa_tdnn.py +268 -0
- model/modules.py +575 -0
- model/trainer.py +245 -0
- model/utils.py +545 -0
- packages.txt +1 -0
- requirements.txt +24 -0
- scripts/count_max_epoch.py +32 -0
- scripts/count_params_gflops.py +35 -0
- scripts/eval_librispeech_test_clean.py +67 -0
- scripts/eval_seedtts_testset.py +69 -0
- scripts/prepare_emilia.py +143 -0
- scripts/prepare_wenetspeech4tts.py +116 -0
- test_infer_batch.py +202 -0
- test_infer_batch.sh +13 -0
- test_infer_single.py +162 -0
- test_train.py +91 -0
- tests/ref_audio/test_en_1_ref_short.wav +0 -0
- tests/ref_audio/test_zh_1_ref_short.wav +0 -0
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2024 Yushen CHEN
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: E2/F5 TTS
|
3 |
+
emoji: 🗣️
|
4 |
+
colorFrom: green
|
5 |
+
colorTo: green
|
6 |
+
sdk: gradio
|
7 |
+
app_file: app.py
|
8 |
+
pinned: true
|
9 |
+
short_description: 'E2-TTS & F5-TTS: Zero-Shot Voice Cloning (Unofficial Demo)'
|
10 |
+
sdk_version: 5.0.1
|
11 |
+
---
|
12 |
+
|
13 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import torch
|
4 |
+
import torchaudio
|
5 |
+
import gradio as gr
|
6 |
+
import numpy as np
|
7 |
+
import tempfile
|
8 |
+
from einops import rearrange
|
9 |
+
from ema_pytorch import EMA
|
10 |
+
from vocos import Vocos
|
11 |
+
from pydub import AudioSegment
|
12 |
+
from model import CFM, UNetT, DiT, MMDiT
|
13 |
+
from cached_path import cached_path
|
14 |
+
from model.utils import (
|
15 |
+
get_tokenizer,
|
16 |
+
convert_char_to_pinyin,
|
17 |
+
save_spectrogram,
|
18 |
+
)
|
19 |
+
from transformers import pipeline
|
20 |
+
import spaces
|
21 |
+
import librosa
|
22 |
+
|
23 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
24 |
+
|
25 |
+
pipe = pipeline(
|
26 |
+
"automatic-speech-recognition",
|
27 |
+
model="openai/whisper-large-v3-turbo",
|
28 |
+
torch_dtype=torch.float16,
|
29 |
+
device=device,
|
30 |
+
)
|
31 |
+
|
32 |
+
# --------------------- Settings -------------------- #
|
33 |
+
|
34 |
+
target_sample_rate = 24000
|
35 |
+
n_mel_channels = 100
|
36 |
+
hop_length = 256
|
37 |
+
target_rms = 0.1
|
38 |
+
nfe_step = 32 # 16, 32
|
39 |
+
cfg_strength = 2.0
|
40 |
+
ode_method = 'euler'
|
41 |
+
sway_sampling_coef = -1.0
|
42 |
+
speed = 1.0
|
43 |
+
# fix_duration = 27 # None or float (duration in seconds)
|
44 |
+
fix_duration = None
|
45 |
+
|
46 |
+
def load_model(exp_name, model_cls, model_cfg, ckpt_step):
|
47 |
+
checkpoint = torch.load(str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.pt")), map_location=device)
|
48 |
+
vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
|
49 |
+
model = CFM(
|
50 |
+
transformer=model_cls(
|
51 |
+
**model_cfg,
|
52 |
+
text_num_embeds=vocab_size,
|
53 |
+
mel_dim=n_mel_channels
|
54 |
+
),
|
55 |
+
mel_spec_kwargs=dict(
|
56 |
+
target_sample_rate=target_sample_rate,
|
57 |
+
n_mel_channels=n_mel_channels,
|
58 |
+
hop_length=hop_length,
|
59 |
+
),
|
60 |
+
odeint_kwargs=dict(
|
61 |
+
method=ode_method,
|
62 |
+
),
|
63 |
+
vocab_char_map=vocab_char_map,
|
64 |
+
).to(device)
|
65 |
+
|
66 |
+
ema_model = EMA(model, include_online_model=False).to(device)
|
67 |
+
ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
|
68 |
+
ema_model.copy_params_from_ema_to_model()
|
69 |
+
|
70 |
+
return ema_model, model
|
71 |
+
|
72 |
+
# load models
|
73 |
+
F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
74 |
+
E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
75 |
+
|
76 |
+
F5TTS_ema_model, F5TTS_base_model = load_model("F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
|
77 |
+
E2TTS_ema_model, E2TTS_base_model = load_model("E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
|
78 |
+
|
79 |
+
@spaces.GPU
|
80 |
+
def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence):
|
81 |
+
print(gen_text)
|
82 |
+
if len(gen_text) > 200:
|
83 |
+
raise gr.Error("Please keep your text under 200 chars.")
|
84 |
+
gr.Info("Converting audio...")
|
85 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
86 |
+
aseg = AudioSegment.from_file(ref_audio_orig)
|
87 |
+
audio_duration = len(aseg)
|
88 |
+
if audio_duration > 15000:
|
89 |
+
gr.Warning("Audio is over 15s, clipping to only first 15s.")
|
90 |
+
aseg = aseg[:15000]
|
91 |
+
aseg.export(f.name, format="wav")
|
92 |
+
ref_audio = f.name
|
93 |
+
if exp_name == "F5-TTS":
|
94 |
+
ema_model = F5TTS_ema_model
|
95 |
+
base_model = F5TTS_base_model
|
96 |
+
elif exp_name == "E2-TTS":
|
97 |
+
ema_model = E2TTS_ema_model
|
98 |
+
base_model = E2TTS_base_model
|
99 |
+
|
100 |
+
if not ref_text.strip():
|
101 |
+
gr.Info("No reference text provided, transcribing reference audio...")
|
102 |
+
ref_text = outputs = pipe(
|
103 |
+
ref_audio,
|
104 |
+
chunk_length_s=30,
|
105 |
+
batch_size=128,
|
106 |
+
generate_kwargs={"task": "transcribe"},
|
107 |
+
return_timestamps=False,
|
108 |
+
)['text'].strip()
|
109 |
+
gr.Info("Finished transcription")
|
110 |
+
else:
|
111 |
+
gr.Info("Using custom reference text...")
|
112 |
+
audio, sr = torchaudio.load(ref_audio)
|
113 |
+
|
114 |
+
rms = torch.sqrt(torch.mean(torch.square(audio)))
|
115 |
+
if rms < target_rms:
|
116 |
+
audio = audio * target_rms / rms
|
117 |
+
if sr != target_sample_rate:
|
118 |
+
resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
|
119 |
+
audio = resampler(audio)
|
120 |
+
audio = audio.to(device)
|
121 |
+
|
122 |
+
# Prepare the text
|
123 |
+
text_list = [ref_text + gen_text]
|
124 |
+
final_text_list = convert_char_to_pinyin(text_list)
|
125 |
+
|
126 |
+
# Calculate duration
|
127 |
+
ref_audio_len = audio.shape[-1] // hop_length
|
128 |
+
# if fix_duration is not None:
|
129 |
+
# duration = int(fix_duration * target_sample_rate / hop_length)
|
130 |
+
# else:
|
131 |
+
zh_pause_punc = r"。,、;:?!"
|
132 |
+
ref_text_len = len(ref_text) + len(re.findall(zh_pause_punc, ref_text))
|
133 |
+
gen_text_len = len(gen_text) + len(re.findall(zh_pause_punc, gen_text))
|
134 |
+
duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
|
135 |
+
|
136 |
+
# inference
|
137 |
+
gr.Info(f"Generating audio using {exp_name}")
|
138 |
+
with torch.inference_mode():
|
139 |
+
generated, _ = base_model.sample(
|
140 |
+
cond=audio,
|
141 |
+
text=final_text_list,
|
142 |
+
duration=duration,
|
143 |
+
steps=nfe_step,
|
144 |
+
cfg_strength=cfg_strength,
|
145 |
+
sway_sampling_coef=sway_sampling_coef,
|
146 |
+
)
|
147 |
+
|
148 |
+
generated = generated[:, ref_audio_len:, :]
|
149 |
+
generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
|
150 |
+
gr.Info("Running vocoder")
|
151 |
+
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
152 |
+
generated_wave = vocos.decode(generated_mel_spec.cpu())
|
153 |
+
if rms < target_rms:
|
154 |
+
generated_wave = generated_wave * rms / target_rms
|
155 |
+
|
156 |
+
# wav -> numpy
|
157 |
+
generated_wave = generated_wave.squeeze().cpu().numpy()
|
158 |
+
|
159 |
+
if remove_silence:
|
160 |
+
gr.Info("Removing audio silences")
|
161 |
+
non_silent_intervals = librosa.effects.split(generated_wave, top_db=30)
|
162 |
+
non_silent_wave = np.array([])
|
163 |
+
for interval in non_silent_intervals:
|
164 |
+
start, end = interval
|
165 |
+
non_silent_wave = np.concatenate([non_silent_wave, generated_wave[start:end]])
|
166 |
+
generated_wave = non_silent_wave
|
167 |
+
|
168 |
+
|
169 |
+
# spectogram
|
170 |
+
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
|
171 |
+
spectrogram_path = tmp_spectrogram.name
|
172 |
+
save_spectrogram(generated_mel_spec[0].cpu().numpy(), spectrogram_path)
|
173 |
+
|
174 |
+
return (target_sample_rate, generated_wave), spectrogram_path
|
175 |
+
|
176 |
+
with gr.Blocks() as app:
|
177 |
+
gr.Markdown("""
|
178 |
+
# E2/F5 TTS
|
179 |
+
|
180 |
+
This is an unofficial E2/F5 TTS demo. This demo supports the following TTS models:
|
181 |
+
|
182 |
+
* [E2-TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
|
183 |
+
* [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
|
184 |
+
|
185 |
+
This demo is based on the [F5-TTS](https://github.com/SWivid/F5-TTS) codebase, which is based on an [unofficial E2-TTS implementation](https://github.com/lucidrains/e2-tts-pytorch).
|
186 |
+
|
187 |
+
The checkpoints support English and Chinese.
|
188 |
+
|
189 |
+
**NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
|
190 |
+
""")
|
191 |
+
|
192 |
+
ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
|
193 |
+
gen_text_input = gr.Textbox(label="Text to Generate (max 200 chars.)", lines=4)
|
194 |
+
model_choice = gr.Radio(choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS")
|
195 |
+
generate_btn = gr.Button("Synthesize", variant="primary")
|
196 |
+
with gr.Accordion("Advanced Settings", open=False):
|
197 |
+
ref_text_input = gr.Textbox(label="Reference Text", info="Leave blank to automatically transcribe the reference audio. If you enter text it will override automatic transcription.", lines=2)
|
198 |
+
remove_silence = gr.Checkbox(label="[EXPERIMENTAL] Remove Silences", info="The model tends to leave silences, we can manually remove silences if needed. This may produce strange results and is not guarenteed to work.")
|
199 |
+
|
200 |
+
audio_output = gr.Audio(label="Synthesized Audio")
|
201 |
+
spectrogram_output = gr.Image(label="Spectrogram")
|
202 |
+
|
203 |
+
generate_btn.click(infer, inputs=[ref_audio_input, ref_text_input, gen_text_input, model_choice, remove_silence], outputs=[audio_output, spectrogram_output])
|
204 |
+
gr.Markdown("Unofficial demo by [mrfakename](https://x.com/realmrfakename)")
|
205 |
+
|
206 |
+
|
207 |
+
app.queue().launch()
|
ckpts/README.md
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
Pretrained model ckpts. https://huggingface.co/SWivid/F5-TTS
|
3 |
+
|
4 |
+
```
|
5 |
+
ckpts/
|
6 |
+
E2TTS_Base/
|
7 |
+
model_1200000.pt
|
8 |
+
F5TTS_Base/
|
9 |
+
model_1200000.pt
|
10 |
+
```
|
data/Emilia_ZH_EN_pinyin/vocab.txt
ADDED
@@ -0,0 +1,2545 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
!
|
3 |
+
"
|
4 |
+
#
|
5 |
+
$
|
6 |
+
%
|
7 |
+
&
|
8 |
+
'
|
9 |
+
(
|
10 |
+
)
|
11 |
+
*
|
12 |
+
+
|
13 |
+
,
|
14 |
+
-
|
15 |
+
.
|
16 |
+
/
|
17 |
+
0
|
18 |
+
1
|
19 |
+
2
|
20 |
+
3
|
21 |
+
4
|
22 |
+
5
|
23 |
+
6
|
24 |
+
7
|
25 |
+
8
|
26 |
+
9
|
27 |
+
:
|
28 |
+
;
|
29 |
+
=
|
30 |
+
>
|
31 |
+
?
|
32 |
+
@
|
33 |
+
A
|
34 |
+
B
|
35 |
+
C
|
36 |
+
D
|
37 |
+
E
|
38 |
+
F
|
39 |
+
G
|
40 |
+
H
|
41 |
+
I
|
42 |
+
J
|
43 |
+
K
|
44 |
+
L
|
45 |
+
M
|
46 |
+
N
|
47 |
+
O
|
48 |
+
P
|
49 |
+
Q
|
50 |
+
R
|
51 |
+
S
|
52 |
+
T
|
53 |
+
U
|
54 |
+
V
|
55 |
+
W
|
56 |
+
X
|
57 |
+
Y
|
58 |
+
Z
|
59 |
+
[
|
60 |
+
\
|
61 |
+
]
|
62 |
+
_
|
63 |
+
a
|
64 |
+
a1
|
65 |
+
ai1
|
66 |
+
ai2
|
67 |
+
ai3
|
68 |
+
ai4
|
69 |
+
an1
|
70 |
+
an3
|
71 |
+
an4
|
72 |
+
ang1
|
73 |
+
ang2
|
74 |
+
ang4
|
75 |
+
ao1
|
76 |
+
ao2
|
77 |
+
ao3
|
78 |
+
ao4
|
79 |
+
b
|
80 |
+
ba
|
81 |
+
ba1
|
82 |
+
ba2
|
83 |
+
ba3
|
84 |
+
ba4
|
85 |
+
bai1
|
86 |
+
bai2
|
87 |
+
bai3
|
88 |
+
bai4
|
89 |
+
ban1
|
90 |
+
ban2
|
91 |
+
ban3
|
92 |
+
ban4
|
93 |
+
bang1
|
94 |
+
bang2
|
95 |
+
bang3
|
96 |
+
bang4
|
97 |
+
bao1
|
98 |
+
bao2
|
99 |
+
bao3
|
100 |
+
bao4
|
101 |
+
bei
|
102 |
+
bei1
|
103 |
+
bei2
|
104 |
+
bei3
|
105 |
+
bei4
|
106 |
+
ben1
|
107 |
+
ben2
|
108 |
+
ben3
|
109 |
+
ben4
|
110 |
+
beng
|
111 |
+
beng1
|
112 |
+
beng2
|
113 |
+
beng3
|
114 |
+
beng4
|
115 |
+
bi1
|
116 |
+
bi2
|
117 |
+
bi3
|
118 |
+
bi4
|
119 |
+
bian1
|
120 |
+
bian2
|
121 |
+
bian3
|
122 |
+
bian4
|
123 |
+
biao1
|
124 |
+
biao2
|
125 |
+
biao3
|
126 |
+
bie1
|
127 |
+
bie2
|
128 |
+
bie3
|
129 |
+
bie4
|
130 |
+
bin1
|
131 |
+
bin4
|
132 |
+
bing1
|
133 |
+
bing2
|
134 |
+
bing3
|
135 |
+
bing4
|
136 |
+
bo
|
137 |
+
bo1
|
138 |
+
bo2
|
139 |
+
bo3
|
140 |
+
bo4
|
141 |
+
bu2
|
142 |
+
bu3
|
143 |
+
bu4
|
144 |
+
c
|
145 |
+
ca1
|
146 |
+
cai1
|
147 |
+
cai2
|
148 |
+
cai3
|
149 |
+
cai4
|
150 |
+
can1
|
151 |
+
can2
|
152 |
+
can3
|
153 |
+
can4
|
154 |
+
cang1
|
155 |
+
cang2
|
156 |
+
cao1
|
157 |
+
cao2
|
158 |
+
cao3
|
159 |
+
ce4
|
160 |
+
cen1
|
161 |
+
cen2
|
162 |
+
ceng1
|
163 |
+
ceng2
|
164 |
+
ceng4
|
165 |
+
cha1
|
166 |
+
cha2
|
167 |
+
cha3
|
168 |
+
cha4
|
169 |
+
chai1
|
170 |
+
chai2
|
171 |
+
chan1
|
172 |
+
chan2
|
173 |
+
chan3
|
174 |
+
chan4
|
175 |
+
chang1
|
176 |
+
chang2
|
177 |
+
chang3
|
178 |
+
chang4
|
179 |
+
chao1
|
180 |
+
chao2
|
181 |
+
chao3
|
182 |
+
che1
|
183 |
+
che2
|
184 |
+
che3
|
185 |
+
che4
|
186 |
+
chen1
|
187 |
+
chen2
|
188 |
+
chen3
|
189 |
+
chen4
|
190 |
+
cheng1
|
191 |
+
cheng2
|
192 |
+
cheng3
|
193 |
+
cheng4
|
194 |
+
chi1
|
195 |
+
chi2
|
196 |
+
chi3
|
197 |
+
chi4
|
198 |
+
chong1
|
199 |
+
chong2
|
200 |
+
chong3
|
201 |
+
chong4
|
202 |
+
chou1
|
203 |
+
chou2
|
204 |
+
chou3
|
205 |
+
chou4
|
206 |
+
chu1
|
207 |
+
chu2
|
208 |
+
chu3
|
209 |
+
chu4
|
210 |
+
chua1
|
211 |
+
chuai1
|
212 |
+
chuai2
|
213 |
+
chuai3
|
214 |
+
chuai4
|
215 |
+
chuan1
|
216 |
+
chuan2
|
217 |
+
chuan3
|
218 |
+
chuan4
|
219 |
+
chuang1
|
220 |
+
chuang2
|
221 |
+
chuang3
|
222 |
+
chuang4
|
223 |
+
chui1
|
224 |
+
chui2
|
225 |
+
chun1
|
226 |
+
chun2
|
227 |
+
chun3
|
228 |
+
chuo1
|
229 |
+
chuo4
|
230 |
+
ci1
|
231 |
+
ci2
|
232 |
+
ci3
|
233 |
+
ci4
|
234 |
+
cong1
|
235 |
+
cong2
|
236 |
+
cou4
|
237 |
+
cu1
|
238 |
+
cu4
|
239 |
+
cuan1
|
240 |
+
cuan2
|
241 |
+
cuan4
|
242 |
+
cui1
|
243 |
+
cui3
|
244 |
+
cui4
|
245 |
+
cun1
|
246 |
+
cun2
|
247 |
+
cun4
|
248 |
+
cuo1
|
249 |
+
cuo2
|
250 |
+
cuo4
|
251 |
+
d
|
252 |
+
da
|
253 |
+
da1
|
254 |
+
da2
|
255 |
+
da3
|
256 |
+
da4
|
257 |
+
dai1
|
258 |
+
dai2
|
259 |
+
dai3
|
260 |
+
dai4
|
261 |
+
dan1
|
262 |
+
dan2
|
263 |
+
dan3
|
264 |
+
dan4
|
265 |
+
dang1
|
266 |
+
dang2
|
267 |
+
dang3
|
268 |
+
dang4
|
269 |
+
dao1
|
270 |
+
dao2
|
271 |
+
dao3
|
272 |
+
dao4
|
273 |
+
de
|
274 |
+
de1
|
275 |
+
de2
|
276 |
+
dei3
|
277 |
+
den4
|
278 |
+
deng1
|
279 |
+
deng2
|
280 |
+
deng3
|
281 |
+
deng4
|
282 |
+
di1
|
283 |
+
di2
|
284 |
+
di3
|
285 |
+
di4
|
286 |
+
dia3
|
287 |
+
dian1
|
288 |
+
dian2
|
289 |
+
dian3
|
290 |
+
dian4
|
291 |
+
diao1
|
292 |
+
diao3
|
293 |
+
diao4
|
294 |
+
die1
|
295 |
+
die2
|
296 |
+
die4
|
297 |
+
ding1
|
298 |
+
ding2
|
299 |
+
ding3
|
300 |
+
ding4
|
301 |
+
diu1
|
302 |
+
dong1
|
303 |
+
dong3
|
304 |
+
dong4
|
305 |
+
dou1
|
306 |
+
dou2
|
307 |
+
dou3
|
308 |
+
dou4
|
309 |
+
du1
|
310 |
+
du2
|
311 |
+
du3
|
312 |
+
du4
|
313 |
+
duan1
|
314 |
+
duan2
|
315 |
+
duan3
|
316 |
+
duan4
|
317 |
+
dui1
|
318 |
+
dui4
|
319 |
+
dun1
|
320 |
+
dun3
|
321 |
+
dun4
|
322 |
+
duo1
|
323 |
+
duo2
|
324 |
+
duo3
|
325 |
+
duo4
|
326 |
+
e
|
327 |
+
e1
|
328 |
+
e2
|
329 |
+
e3
|
330 |
+
e4
|
331 |
+
ei2
|
332 |
+
en1
|
333 |
+
en4
|
334 |
+
er
|
335 |
+
er2
|
336 |
+
er3
|
337 |
+
er4
|
338 |
+
f
|
339 |
+
fa1
|
340 |
+
fa2
|
341 |
+
fa3
|
342 |
+
fa4
|
343 |
+
fan1
|
344 |
+
fan2
|
345 |
+
fan3
|
346 |
+
fan4
|
347 |
+
fang1
|
348 |
+
fang2
|
349 |
+
fang3
|
350 |
+
fang4
|
351 |
+
fei1
|
352 |
+
fei2
|
353 |
+
fei3
|
354 |
+
fei4
|
355 |
+
fen1
|
356 |
+
fen2
|
357 |
+
fen3
|
358 |
+
fen4
|
359 |
+
feng1
|
360 |
+
feng2
|
361 |
+
feng3
|
362 |
+
feng4
|
363 |
+
fo2
|
364 |
+
fou2
|
365 |
+
fou3
|
366 |
+
fu1
|
367 |
+
fu2
|
368 |
+
fu3
|
369 |
+
fu4
|
370 |
+
g
|
371 |
+
ga1
|
372 |
+
ga2
|
373 |
+
ga3
|
374 |
+
ga4
|
375 |
+
gai1
|
376 |
+
gai2
|
377 |
+
gai3
|
378 |
+
gai4
|
379 |
+
gan1
|
380 |
+
gan2
|
381 |
+
gan3
|
382 |
+
gan4
|
383 |
+
gang1
|
384 |
+
gang2
|
385 |
+
gang3
|
386 |
+
gang4
|
387 |
+
gao1
|
388 |
+
gao2
|
389 |
+
gao3
|
390 |
+
gao4
|
391 |
+
ge1
|
392 |
+
ge2
|
393 |
+
ge3
|
394 |
+
ge4
|
395 |
+
gei2
|
396 |
+
gei3
|
397 |
+
gen1
|
398 |
+
gen2
|
399 |
+
gen3
|
400 |
+
gen4
|
401 |
+
geng1
|
402 |
+
geng3
|
403 |
+
geng4
|
404 |
+
gong1
|
405 |
+
gong3
|
406 |
+
gong4
|
407 |
+
gou1
|
408 |
+
gou2
|
409 |
+
gou3
|
410 |
+
gou4
|
411 |
+
gu
|
412 |
+
gu1
|
413 |
+
gu2
|
414 |
+
gu3
|
415 |
+
gu4
|
416 |
+
gua1
|
417 |
+
gua2
|
418 |
+
gua3
|
419 |
+
gua4
|
420 |
+
guai1
|
421 |
+
guai2
|
422 |
+
guai3
|
423 |
+
guai4
|
424 |
+
guan1
|
425 |
+
guan2
|
426 |
+
guan3
|
427 |
+
guan4
|
428 |
+
guang1
|
429 |
+
guang2
|
430 |
+
guang3
|
431 |
+
guang4
|
432 |
+
gui1
|
433 |
+
gui2
|
434 |
+
gui3
|
435 |
+
gui4
|
436 |
+
gun3
|
437 |
+
gun4
|
438 |
+
guo1
|
439 |
+
guo2
|
440 |
+
guo3
|
441 |
+
guo4
|
442 |
+
h
|
443 |
+
ha1
|
444 |
+
ha2
|
445 |
+
ha3
|
446 |
+
hai1
|
447 |
+
hai2
|
448 |
+
hai3
|
449 |
+
hai4
|
450 |
+
han1
|
451 |
+
han2
|
452 |
+
han3
|
453 |
+
han4
|
454 |
+
hang1
|
455 |
+
hang2
|
456 |
+
hang4
|
457 |
+
hao1
|
458 |
+
hao2
|
459 |
+
hao3
|
460 |
+
hao4
|
461 |
+
he1
|
462 |
+
he2
|
463 |
+
he4
|
464 |
+
hei1
|
465 |
+
hen2
|
466 |
+
hen3
|
467 |
+
hen4
|
468 |
+
heng1
|
469 |
+
heng2
|
470 |
+
heng4
|
471 |
+
hong1
|
472 |
+
hong2
|
473 |
+
hong3
|
474 |
+
hong4
|
475 |
+
hou1
|
476 |
+
hou2
|
477 |
+
hou3
|
478 |
+
hou4
|
479 |
+
hu1
|
480 |
+
hu2
|
481 |
+
hu3
|
482 |
+
hu4
|
483 |
+
hua1
|
484 |
+
hua2
|
485 |
+
hua4
|
486 |
+
huai2
|
487 |
+
huai4
|
488 |
+
huan1
|
489 |
+
huan2
|
490 |
+
huan3
|
491 |
+
huan4
|
492 |
+
huang1
|
493 |
+
huang2
|
494 |
+
huang3
|
495 |
+
huang4
|
496 |
+
hui1
|
497 |
+
hui2
|
498 |
+
hui3
|
499 |
+
hui4
|
500 |
+
hun1
|
501 |
+
hun2
|
502 |
+
hun4
|
503 |
+
huo
|
504 |
+
huo1
|
505 |
+
huo2
|
506 |
+
huo3
|
507 |
+
huo4
|
508 |
+
i
|
509 |
+
j
|
510 |
+
ji1
|
511 |
+
ji2
|
512 |
+
ji3
|
513 |
+
ji4
|
514 |
+
jia
|
515 |
+
jia1
|
516 |
+
jia2
|
517 |
+
jia3
|
518 |
+
jia4
|
519 |
+
jian1
|
520 |
+
jian2
|
521 |
+
jian3
|
522 |
+
jian4
|
523 |
+
jiang1
|
524 |
+
jiang2
|
525 |
+
jiang3
|
526 |
+
jiang4
|
527 |
+
jiao1
|
528 |
+
jiao2
|
529 |
+
jiao3
|
530 |
+
jiao4
|
531 |
+
jie1
|
532 |
+
jie2
|
533 |
+
jie3
|
534 |
+
jie4
|
535 |
+
jin1
|
536 |
+
jin2
|
537 |
+
jin3
|
538 |
+
jin4
|
539 |
+
jing1
|
540 |
+
jing2
|
541 |
+
jing3
|
542 |
+
jing4
|
543 |
+
jiong3
|
544 |
+
jiu1
|
545 |
+
jiu2
|
546 |
+
jiu3
|
547 |
+
jiu4
|
548 |
+
ju1
|
549 |
+
ju2
|
550 |
+
ju3
|
551 |
+
ju4
|
552 |
+
juan1
|
553 |
+
juan2
|
554 |
+
juan3
|
555 |
+
juan4
|
556 |
+
jue1
|
557 |
+
jue2
|
558 |
+
jue4
|
559 |
+
jun1
|
560 |
+
jun4
|
561 |
+
k
|
562 |
+
ka1
|
563 |
+
ka2
|
564 |
+
ka3
|
565 |
+
kai1
|
566 |
+
kai2
|
567 |
+
kai3
|
568 |
+
kai4
|
569 |
+
kan1
|
570 |
+
kan2
|
571 |
+
kan3
|
572 |
+
kan4
|
573 |
+
kang1
|
574 |
+
kang2
|
575 |
+
kang4
|
576 |
+
kao1
|
577 |
+
kao2
|
578 |
+
kao3
|
579 |
+
kao4
|
580 |
+
ke1
|
581 |
+
ke2
|
582 |
+
ke3
|
583 |
+
ke4
|
584 |
+
ken3
|
585 |
+
keng1
|
586 |
+
kong1
|
587 |
+
kong3
|
588 |
+
kong4
|
589 |
+
kou1
|
590 |
+
kou2
|
591 |
+
kou3
|
592 |
+
kou4
|
593 |
+
ku1
|
594 |
+
ku2
|
595 |
+
ku3
|
596 |
+
ku4
|
597 |
+
kua1
|
598 |
+
kua3
|
599 |
+
kua4
|
600 |
+
kuai3
|
601 |
+
kuai4
|
602 |
+
kuan1
|
603 |
+
kuan2
|
604 |
+
kuan3
|
605 |
+
kuang1
|
606 |
+
kuang2
|
607 |
+
kuang4
|
608 |
+
kui1
|
609 |
+
kui2
|
610 |
+
kui3
|
611 |
+
kui4
|
612 |
+
kun1
|
613 |
+
kun3
|
614 |
+
kun4
|
615 |
+
kuo4
|
616 |
+
l
|
617 |
+
la
|
618 |
+
la1
|
619 |
+
la2
|
620 |
+
la3
|
621 |
+
la4
|
622 |
+
lai2
|
623 |
+
lai4
|
624 |
+
lan2
|
625 |
+
lan3
|
626 |
+
lan4
|
627 |
+
lang1
|
628 |
+
lang2
|
629 |
+
lang3
|
630 |
+
lang4
|
631 |
+
lao1
|
632 |
+
lao2
|
633 |
+
lao3
|
634 |
+
lao4
|
635 |
+
le
|
636 |
+
le1
|
637 |
+
le4
|
638 |
+
lei
|
639 |
+
lei1
|
640 |
+
lei2
|
641 |
+
lei3
|
642 |
+
lei4
|
643 |
+
leng1
|
644 |
+
leng2
|
645 |
+
leng3
|
646 |
+
leng4
|
647 |
+
li
|
648 |
+
li1
|
649 |
+
li2
|
650 |
+
li3
|
651 |
+
li4
|
652 |
+
lia3
|
653 |
+
lian2
|
654 |
+
lian3
|
655 |
+
lian4
|
656 |
+
liang2
|
657 |
+
liang3
|
658 |
+
liang4
|
659 |
+
liao1
|
660 |
+
liao2
|
661 |
+
liao3
|
662 |
+
liao4
|
663 |
+
lie1
|
664 |
+
lie2
|
665 |
+
lie3
|
666 |
+
lie4
|
667 |
+
lin1
|
668 |
+
lin2
|
669 |
+
lin3
|
670 |
+
lin4
|
671 |
+
ling2
|
672 |
+
ling3
|
673 |
+
ling4
|
674 |
+
liu1
|
675 |
+
liu2
|
676 |
+
liu3
|
677 |
+
liu4
|
678 |
+
long1
|
679 |
+
long2
|
680 |
+
long3
|
681 |
+
long4
|
682 |
+
lou1
|
683 |
+
lou2
|
684 |
+
lou3
|
685 |
+
lou4
|
686 |
+
lu1
|
687 |
+
lu2
|
688 |
+
lu3
|
689 |
+
lu4
|
690 |
+
luan2
|
691 |
+
luan3
|
692 |
+
luan4
|
693 |
+
lun1
|
694 |
+
lun2
|
695 |
+
lun4
|
696 |
+
luo1
|
697 |
+
luo2
|
698 |
+
luo3
|
699 |
+
luo4
|
700 |
+
lv2
|
701 |
+
lv3
|
702 |
+
lv4
|
703 |
+
lve3
|
704 |
+
lve4
|
705 |
+
m
|
706 |
+
ma
|
707 |
+
ma1
|
708 |
+
ma2
|
709 |
+
ma3
|
710 |
+
ma4
|
711 |
+
mai2
|
712 |
+
mai3
|
713 |
+
mai4
|
714 |
+
man1
|
715 |
+
man2
|
716 |
+
man3
|
717 |
+
man4
|
718 |
+
mang2
|
719 |
+
mang3
|
720 |
+
mao1
|
721 |
+
mao2
|
722 |
+
mao3
|
723 |
+
mao4
|
724 |
+
me
|
725 |
+
mei2
|
726 |
+
mei3
|
727 |
+
mei4
|
728 |
+
men
|
729 |
+
men1
|
730 |
+
men2
|
731 |
+
men4
|
732 |
+
meng
|
733 |
+
meng1
|
734 |
+
meng2
|
735 |
+
meng3
|
736 |
+
meng4
|
737 |
+
mi1
|
738 |
+
mi2
|
739 |
+
mi3
|
740 |
+
mi4
|
741 |
+
mian2
|
742 |
+
mian3
|
743 |
+
mian4
|
744 |
+
miao1
|
745 |
+
miao2
|
746 |
+
miao3
|
747 |
+
miao4
|
748 |
+
mie1
|
749 |
+
mie4
|
750 |
+
min2
|
751 |
+
min3
|
752 |
+
ming2
|
753 |
+
ming3
|
754 |
+
ming4
|
755 |
+
miu4
|
756 |
+
mo1
|
757 |
+
mo2
|
758 |
+
mo3
|
759 |
+
mo4
|
760 |
+
mou1
|
761 |
+
mou2
|
762 |
+
mou3
|
763 |
+
mu2
|
764 |
+
mu3
|
765 |
+
mu4
|
766 |
+
n
|
767 |
+
n2
|
768 |
+
na1
|
769 |
+
na2
|
770 |
+
na3
|
771 |
+
na4
|
772 |
+
nai2
|
773 |
+
nai3
|
774 |
+
nai4
|
775 |
+
nan1
|
776 |
+
nan2
|
777 |
+
nan3
|
778 |
+
nan4
|
779 |
+
nang1
|
780 |
+
nang2
|
781 |
+
nang3
|
782 |
+
nao1
|
783 |
+
nao2
|
784 |
+
nao3
|
785 |
+
nao4
|
786 |
+
ne
|
787 |
+
ne2
|
788 |
+
ne4
|
789 |
+
nei3
|
790 |
+
nei4
|
791 |
+
nen4
|
792 |
+
neng2
|
793 |
+
ni1
|
794 |
+
ni2
|
795 |
+
ni3
|
796 |
+
ni4
|
797 |
+
nian1
|
798 |
+
nian2
|
799 |
+
nian3
|
800 |
+
nian4
|
801 |
+
niang2
|
802 |
+
niang4
|
803 |
+
niao2
|
804 |
+
niao3
|
805 |
+
niao4
|
806 |
+
nie1
|
807 |
+
nie4
|
808 |
+
nin2
|
809 |
+
ning2
|
810 |
+
ning3
|
811 |
+
ning4
|
812 |
+
niu1
|
813 |
+
niu2
|
814 |
+
niu3
|
815 |
+
niu4
|
816 |
+
nong2
|
817 |
+
nong4
|
818 |
+
nou4
|
819 |
+
nu2
|
820 |
+
nu3
|
821 |
+
nu4
|
822 |
+
nuan3
|
823 |
+
nuo2
|
824 |
+
nuo4
|
825 |
+
nv2
|
826 |
+
nv3
|
827 |
+
nve4
|
828 |
+
o
|
829 |
+
o1
|
830 |
+
o2
|
831 |
+
ou1
|
832 |
+
ou2
|
833 |
+
ou3
|
834 |
+
ou4
|
835 |
+
p
|
836 |
+
pa1
|
837 |
+
pa2
|
838 |
+
pa4
|
839 |
+
pai1
|
840 |
+
pai2
|
841 |
+
pai3
|
842 |
+
pai4
|
843 |
+
pan1
|
844 |
+
pan2
|
845 |
+
pan4
|
846 |
+
pang1
|
847 |
+
pang2
|
848 |
+
pang4
|
849 |
+
pao1
|
850 |
+
pao2
|
851 |
+
pao3
|
852 |
+
pao4
|
853 |
+
pei1
|
854 |
+
pei2
|
855 |
+
pei4
|
856 |
+
pen1
|
857 |
+
pen2
|
858 |
+
pen4
|
859 |
+
peng1
|
860 |
+
peng2
|
861 |
+
peng3
|
862 |
+
peng4
|
863 |
+
pi1
|
864 |
+
pi2
|
865 |
+
pi3
|
866 |
+
pi4
|
867 |
+
pian1
|
868 |
+
pian2
|
869 |
+
pian4
|
870 |
+
piao1
|
871 |
+
piao2
|
872 |
+
piao3
|
873 |
+
piao4
|
874 |
+
pie1
|
875 |
+
pie2
|
876 |
+
pie3
|
877 |
+
pin1
|
878 |
+
pin2
|
879 |
+
pin3
|
880 |
+
pin4
|
881 |
+
ping1
|
882 |
+
ping2
|
883 |
+
po1
|
884 |
+
po2
|
885 |
+
po3
|
886 |
+
po4
|
887 |
+
pou1
|
888 |
+
pu1
|
889 |
+
pu2
|
890 |
+
pu3
|
891 |
+
pu4
|
892 |
+
q
|
893 |
+
qi1
|
894 |
+
qi2
|
895 |
+
qi3
|
896 |
+
qi4
|
897 |
+
qia1
|
898 |
+
qia3
|
899 |
+
qia4
|
900 |
+
qian1
|
901 |
+
qian2
|
902 |
+
qian3
|
903 |
+
qian4
|
904 |
+
qiang1
|
905 |
+
qiang2
|
906 |
+
qiang3
|
907 |
+
qiang4
|
908 |
+
qiao1
|
909 |
+
qiao2
|
910 |
+
qiao3
|
911 |
+
qiao4
|
912 |
+
qie1
|
913 |
+
qie2
|
914 |
+
qie3
|
915 |
+
qie4
|
916 |
+
qin1
|
917 |
+
qin2
|
918 |
+
qin3
|
919 |
+
qin4
|
920 |
+
qing1
|
921 |
+
qing2
|
922 |
+
qing3
|
923 |
+
qing4
|
924 |
+
qiong1
|
925 |
+
qiong2
|
926 |
+
qiu1
|
927 |
+
qiu2
|
928 |
+
qiu3
|
929 |
+
qu1
|
930 |
+
qu2
|
931 |
+
qu3
|
932 |
+
qu4
|
933 |
+
quan1
|
934 |
+
quan2
|
935 |
+
quan3
|
936 |
+
quan4
|
937 |
+
que1
|
938 |
+
que2
|
939 |
+
que4
|
940 |
+
qun2
|
941 |
+
r
|
942 |
+
ran2
|
943 |
+
ran3
|
944 |
+
rang1
|
945 |
+
rang2
|
946 |
+
rang3
|
947 |
+
rang4
|
948 |
+
rao2
|
949 |
+
rao3
|
950 |
+
rao4
|
951 |
+
re2
|
952 |
+
re3
|
953 |
+
re4
|
954 |
+
ren2
|
955 |
+
ren3
|
956 |
+
ren4
|
957 |
+
reng1
|
958 |
+
reng2
|
959 |
+
ri4
|
960 |
+
rong1
|
961 |
+
rong2
|
962 |
+
rong3
|
963 |
+
rou2
|
964 |
+
rou4
|
965 |
+
ru2
|
966 |
+
ru3
|
967 |
+
ru4
|
968 |
+
ruan2
|
969 |
+
ruan3
|
970 |
+
rui3
|
971 |
+
rui4
|
972 |
+
run4
|
973 |
+
ruo4
|
974 |
+
s
|
975 |
+
sa1
|
976 |
+
sa2
|
977 |
+
sa3
|
978 |
+
sa4
|
979 |
+
sai1
|
980 |
+
sai4
|
981 |
+
san1
|
982 |
+
san2
|
983 |
+
san3
|
984 |
+
san4
|
985 |
+
sang1
|
986 |
+
sang3
|
987 |
+
sang4
|
988 |
+
sao1
|
989 |
+
sao2
|
990 |
+
sao3
|
991 |
+
sao4
|
992 |
+
se4
|
993 |
+
sen1
|
994 |
+
seng1
|
995 |
+
sha1
|
996 |
+
sha2
|
997 |
+
sha3
|
998 |
+
sha4
|
999 |
+
shai1
|
1000 |
+
shai2
|
1001 |
+
shai3
|
1002 |
+
shai4
|
1003 |
+
shan1
|
1004 |
+
shan3
|
1005 |
+
shan4
|
1006 |
+
shang
|
1007 |
+
shang1
|
1008 |
+
shang3
|
1009 |
+
shang4
|
1010 |
+
shao1
|
1011 |
+
shao2
|
1012 |
+
shao3
|
1013 |
+
shao4
|
1014 |
+
she1
|
1015 |
+
she2
|
1016 |
+
she3
|
1017 |
+
she4
|
1018 |
+
shei2
|
1019 |
+
shen1
|
1020 |
+
shen2
|
1021 |
+
shen3
|
1022 |
+
shen4
|
1023 |
+
sheng1
|
1024 |
+
sheng2
|
1025 |
+
sheng3
|
1026 |
+
sheng4
|
1027 |
+
shi
|
1028 |
+
shi1
|
1029 |
+
shi2
|
1030 |
+
shi3
|
1031 |
+
shi4
|
1032 |
+
shou1
|
1033 |
+
shou2
|
1034 |
+
shou3
|
1035 |
+
shou4
|
1036 |
+
shu1
|
1037 |
+
shu2
|
1038 |
+
shu3
|
1039 |
+
shu4
|
1040 |
+
shua1
|
1041 |
+
shua2
|
1042 |
+
shua3
|
1043 |
+
shua4
|
1044 |
+
shuai1
|
1045 |
+
shuai3
|
1046 |
+
shuai4
|
1047 |
+
shuan1
|
1048 |
+
shuan4
|
1049 |
+
shuang1
|
1050 |
+
shuang3
|
1051 |
+
shui2
|
1052 |
+
shui3
|
1053 |
+
shui4
|
1054 |
+
shun3
|
1055 |
+
shun4
|
1056 |
+
shuo1
|
1057 |
+
shuo4
|
1058 |
+
si1
|
1059 |
+
si2
|
1060 |
+
si3
|
1061 |
+
si4
|
1062 |
+
song1
|
1063 |
+
song3
|
1064 |
+
song4
|
1065 |
+
sou1
|
1066 |
+
sou3
|
1067 |
+
sou4
|
1068 |
+
su1
|
1069 |
+
su2
|
1070 |
+
su4
|
1071 |
+
suan1
|
1072 |
+
suan4
|
1073 |
+
sui1
|
1074 |
+
sui2
|
1075 |
+
sui3
|
1076 |
+
sui4
|
1077 |
+
sun1
|
1078 |
+
sun3
|
1079 |
+
suo
|
1080 |
+
suo1
|
1081 |
+
suo2
|
1082 |
+
suo3
|
1083 |
+
t
|
1084 |
+
ta1
|
1085 |
+
ta2
|
1086 |
+
ta3
|
1087 |
+
ta4
|
1088 |
+
tai1
|
1089 |
+
tai2
|
1090 |
+
tai4
|
1091 |
+
tan1
|
1092 |
+
tan2
|
1093 |
+
tan3
|
1094 |
+
tan4
|
1095 |
+
tang1
|
1096 |
+
tang2
|
1097 |
+
tang3
|
1098 |
+
tang4
|
1099 |
+
tao1
|
1100 |
+
tao2
|
1101 |
+
tao3
|
1102 |
+
tao4
|
1103 |
+
te4
|
1104 |
+
teng2
|
1105 |
+
ti1
|
1106 |
+
ti2
|
1107 |
+
ti3
|
1108 |
+
ti4
|
1109 |
+
tian1
|
1110 |
+
tian2
|
1111 |
+
tian3
|
1112 |
+
tiao1
|
1113 |
+
tiao2
|
1114 |
+
tiao3
|
1115 |
+
tiao4
|
1116 |
+
tie1
|
1117 |
+
tie2
|
1118 |
+
tie3
|
1119 |
+
tie4
|
1120 |
+
ting1
|
1121 |
+
ting2
|
1122 |
+
ting3
|
1123 |
+
tong1
|
1124 |
+
tong2
|
1125 |
+
tong3
|
1126 |
+
tong4
|
1127 |
+
tou
|
1128 |
+
tou1
|
1129 |
+
tou2
|
1130 |
+
tou4
|
1131 |
+
tu1
|
1132 |
+
tu2
|
1133 |
+
tu3
|
1134 |
+
tu4
|
1135 |
+
tuan1
|
1136 |
+
tuan2
|
1137 |
+
tui1
|
1138 |
+
tui2
|
1139 |
+
tui3
|
1140 |
+
tui4
|
1141 |
+
tun1
|
1142 |
+
tun2
|
1143 |
+
tun4
|
1144 |
+
tuo1
|
1145 |
+
tuo2
|
1146 |
+
tuo3
|
1147 |
+
tuo4
|
1148 |
+
u
|
1149 |
+
v
|
1150 |
+
w
|
1151 |
+
wa
|
1152 |
+
wa1
|
1153 |
+
wa2
|
1154 |
+
wa3
|
1155 |
+
wa4
|
1156 |
+
wai1
|
1157 |
+
wai3
|
1158 |
+
wai4
|
1159 |
+
wan1
|
1160 |
+
wan2
|
1161 |
+
wan3
|
1162 |
+
wan4
|
1163 |
+
wang1
|
1164 |
+
wang2
|
1165 |
+
wang3
|
1166 |
+
wang4
|
1167 |
+
wei1
|
1168 |
+
wei2
|
1169 |
+
wei3
|
1170 |
+
wei4
|
1171 |
+
wen1
|
1172 |
+
wen2
|
1173 |
+
wen3
|
1174 |
+
wen4
|
1175 |
+
weng1
|
1176 |
+
weng4
|
1177 |
+
wo1
|
1178 |
+
wo2
|
1179 |
+
wo3
|
1180 |
+
wo4
|
1181 |
+
wu1
|
1182 |
+
wu2
|
1183 |
+
wu3
|
1184 |
+
wu4
|
1185 |
+
x
|
1186 |
+
xi1
|
1187 |
+
xi2
|
1188 |
+
xi3
|
1189 |
+
xi4
|
1190 |
+
xia1
|
1191 |
+
xia2
|
1192 |
+
xia4
|
1193 |
+
xian1
|
1194 |
+
xian2
|
1195 |
+
xian3
|
1196 |
+
xian4
|
1197 |
+
xiang1
|
1198 |
+
xiang2
|
1199 |
+
xiang3
|
1200 |
+
xiang4
|
1201 |
+
xiao1
|
1202 |
+
xiao2
|
1203 |
+
xiao3
|
1204 |
+
xiao4
|
1205 |
+
xie1
|
1206 |
+
xie2
|
1207 |
+
xie3
|
1208 |
+
xie4
|
1209 |
+
xin1
|
1210 |
+
xin2
|
1211 |
+
xin4
|
1212 |
+
xing1
|
1213 |
+
xing2
|
1214 |
+
xing3
|
1215 |
+
xing4
|
1216 |
+
xiong1
|
1217 |
+
xiong2
|
1218 |
+
xiu1
|
1219 |
+
xiu3
|
1220 |
+
xiu4
|
1221 |
+
xu
|
1222 |
+
xu1
|
1223 |
+
xu2
|
1224 |
+
xu3
|
1225 |
+
xu4
|
1226 |
+
xuan1
|
1227 |
+
xuan2
|
1228 |
+
xuan3
|
1229 |
+
xuan4
|
1230 |
+
xue1
|
1231 |
+
xue2
|
1232 |
+
xue3
|
1233 |
+
xue4
|
1234 |
+
xun1
|
1235 |
+
xun2
|
1236 |
+
xun4
|
1237 |
+
y
|
1238 |
+
ya
|
1239 |
+
ya1
|
1240 |
+
ya2
|
1241 |
+
ya3
|
1242 |
+
ya4
|
1243 |
+
yan1
|
1244 |
+
yan2
|
1245 |
+
yan3
|
1246 |
+
yan4
|
1247 |
+
yang1
|
1248 |
+
yang2
|
1249 |
+
yang3
|
1250 |
+
yang4
|
1251 |
+
yao1
|
1252 |
+
yao2
|
1253 |
+
yao3
|
1254 |
+
yao4
|
1255 |
+
ye1
|
1256 |
+
ye2
|
1257 |
+
ye3
|
1258 |
+
ye4
|
1259 |
+
yi
|
1260 |
+
yi1
|
1261 |
+
yi2
|
1262 |
+
yi3
|
1263 |
+
yi4
|
1264 |
+
yin1
|
1265 |
+
yin2
|
1266 |
+
yin3
|
1267 |
+
yin4
|
1268 |
+
ying1
|
1269 |
+
ying2
|
1270 |
+
ying3
|
1271 |
+
ying4
|
1272 |
+
yo1
|
1273 |
+
yong1
|
1274 |
+
yong2
|
1275 |
+
yong3
|
1276 |
+
yong4
|
1277 |
+
you1
|
1278 |
+
you2
|
1279 |
+
you3
|
1280 |
+
you4
|
1281 |
+
yu1
|
1282 |
+
yu2
|
1283 |
+
yu3
|
1284 |
+
yu4
|
1285 |
+
yuan1
|
1286 |
+
yuan2
|
1287 |
+
yuan3
|
1288 |
+
yuan4
|
1289 |
+
yue1
|
1290 |
+
yue4
|
1291 |
+
yun1
|
1292 |
+
yun2
|
1293 |
+
yun3
|
1294 |
+
yun4
|
1295 |
+
z
|
1296 |
+
za1
|
1297 |
+
za2
|
1298 |
+
za3
|
1299 |
+
zai1
|
1300 |
+
zai3
|
1301 |
+
zai4
|
1302 |
+
zan1
|
1303 |
+
zan2
|
1304 |
+
zan3
|
1305 |
+
zan4
|
1306 |
+
zang1
|
1307 |
+
zang4
|
1308 |
+
zao1
|
1309 |
+
zao2
|
1310 |
+
zao3
|
1311 |
+
zao4
|
1312 |
+
ze2
|
1313 |
+
ze4
|
1314 |
+
zei2
|
1315 |
+
zen3
|
1316 |
+
zeng1
|
1317 |
+
zeng4
|
1318 |
+
zha1
|
1319 |
+
zha2
|
1320 |
+
zha3
|
1321 |
+
zha4
|
1322 |
+
zhai1
|
1323 |
+
zhai2
|
1324 |
+
zhai3
|
1325 |
+
zhai4
|
1326 |
+
zhan1
|
1327 |
+
zhan2
|
1328 |
+
zhan3
|
1329 |
+
zhan4
|
1330 |
+
zhang1
|
1331 |
+
zhang2
|
1332 |
+
zhang3
|
1333 |
+
zhang4
|
1334 |
+
zhao1
|
1335 |
+
zhao2
|
1336 |
+
zhao3
|
1337 |
+
zhao4
|
1338 |
+
zhe
|
1339 |
+
zhe1
|
1340 |
+
zhe2
|
1341 |
+
zhe3
|
1342 |
+
zhe4
|
1343 |
+
zhen1
|
1344 |
+
zhen2
|
1345 |
+
zhen3
|
1346 |
+
zhen4
|
1347 |
+
zheng1
|
1348 |
+
zheng2
|
1349 |
+
zheng3
|
1350 |
+
zheng4
|
1351 |
+
zhi1
|
1352 |
+
zhi2
|
1353 |
+
zhi3
|
1354 |
+
zhi4
|
1355 |
+
zhong1
|
1356 |
+
zhong2
|
1357 |
+
zhong3
|
1358 |
+
zhong4
|
1359 |
+
zhou1
|
1360 |
+
zhou2
|
1361 |
+
zhou3
|
1362 |
+
zhou4
|
1363 |
+
zhu1
|
1364 |
+
zhu2
|
1365 |
+
zhu3
|
1366 |
+
zhu4
|
1367 |
+
zhua1
|
1368 |
+
zhua2
|
1369 |
+
zhua3
|
1370 |
+
zhuai1
|
1371 |
+
zhuai3
|
1372 |
+
zhuai4
|
1373 |
+
zhuan1
|
1374 |
+
zhuan2
|
1375 |
+
zhuan3
|
1376 |
+
zhuan4
|
1377 |
+
zhuang1
|
1378 |
+
zhuang4
|
1379 |
+
zhui1
|
1380 |
+
zhui4
|
1381 |
+
zhun1
|
1382 |
+
zhun2
|
1383 |
+
zhun3
|
1384 |
+
zhuo1
|
1385 |
+
zhuo2
|
1386 |
+
zi
|
1387 |
+
zi1
|
1388 |
+
zi2
|
1389 |
+
zi3
|
1390 |
+
zi4
|
1391 |
+
zong1
|
1392 |
+
zong2
|
1393 |
+
zong3
|
1394 |
+
zong4
|
1395 |
+
zou1
|
1396 |
+
zou2
|
1397 |
+
zou3
|
1398 |
+
zou4
|
1399 |
+
zu1
|
1400 |
+
zu2
|
1401 |
+
zu3
|
1402 |
+
zuan1
|
1403 |
+
zuan3
|
1404 |
+
zuan4
|
1405 |
+
zui2
|
1406 |
+
zui3
|
1407 |
+
zui4
|
1408 |
+
zun1
|
1409 |
+
zuo
|
1410 |
+
zuo1
|
1411 |
+
zuo2
|
1412 |
+
zuo3
|
1413 |
+
zuo4
|
1414 |
+
{
|
1415 |
+
~
|
1416 |
+
¡
|
1417 |
+
¢
|
1418 |
+
£
|
1419 |
+
¥
|
1420 |
+
§
|
1421 |
+
¨
|
1422 |
+
©
|
1423 |
+
«
|
1424 |
+
®
|
1425 |
+
¯
|
1426 |
+
°
|
1427 |
+
±
|
1428 |
+
²
|
1429 |
+
³
|
1430 |
+
´
|
1431 |
+
µ
|
1432 |
+
·
|
1433 |
+
¹
|
1434 |
+
º
|
1435 |
+
»
|
1436 |
+
¼
|
1437 |
+
½
|
1438 |
+
¾
|
1439 |
+
¿
|
1440 |
+
À
|
1441 |
+
Á
|
1442 |
+
Â
|
1443 |
+
Ã
|
1444 |
+
Ä
|
1445 |
+
Å
|
1446 |
+
Æ
|
1447 |
+
Ç
|
1448 |
+
È
|
1449 |
+
É
|
1450 |
+
Ê
|
1451 |
+
Í
|
1452 |
+
Î
|
1453 |
+
Ñ
|
1454 |
+
Ó
|
1455 |
+
Ö
|
1456 |
+
×
|
1457 |
+
Ø
|
1458 |
+
Ú
|
1459 |
+
Ü
|
1460 |
+
Ý
|
1461 |
+
Þ
|
1462 |
+
ß
|
1463 |
+
à
|
1464 |
+
á
|
1465 |
+
â
|
1466 |
+
ã
|
1467 |
+
ä
|
1468 |
+
å
|
1469 |
+
æ
|
1470 |
+
ç
|
1471 |
+
è
|
1472 |
+
é
|
1473 |
+
ê
|
1474 |
+
ë
|
1475 |
+
ì
|
1476 |
+
í
|
1477 |
+
î
|
1478 |
+
ï
|
1479 |
+
ð
|
1480 |
+
ñ
|
1481 |
+
ò
|
1482 |
+
ó
|
1483 |
+
ô
|
1484 |
+
õ
|
1485 |
+
ö
|
1486 |
+
ø
|
1487 |
+
ù
|
1488 |
+
ú
|
1489 |
+
û
|
1490 |
+
ü
|
1491 |
+
ý
|
1492 |
+
Ā
|
1493 |
+
ā
|
1494 |
+
ă
|
1495 |
+
ą
|
1496 |
+
ć
|
1497 |
+
Č
|
1498 |
+
č
|
1499 |
+
Đ
|
1500 |
+
đ
|
1501 |
+
ē
|
1502 |
+
ė
|
1503 |
+
ę
|
1504 |
+
ě
|
1505 |
+
ĝ
|
1506 |
+
ğ
|
1507 |
+
ħ
|
1508 |
+
ī
|
1509 |
+
į
|
1510 |
+
İ
|
1511 |
+
ı
|
1512 |
+
Ł
|
1513 |
+
ł
|
1514 |
+
ń
|
1515 |
+
ņ
|
1516 |
+
ň
|
1517 |
+
ŋ
|
1518 |
+
Ō
|
1519 |
+
ō
|
1520 |
+
ő
|
1521 |
+
œ
|
1522 |
+
ř
|
1523 |
+
Ś
|
1524 |
+
ś
|
1525 |
+
Ş
|
1526 |
+
ş
|
1527 |
+
Š
|
1528 |
+
š
|
1529 |
+
Ť
|
1530 |
+
ť
|
1531 |
+
ũ
|
1532 |
+
ū
|
1533 |
+
ź
|
1534 |
+
Ż
|
1535 |
+
ż
|
1536 |
+
Ž
|
1537 |
+
ž
|
1538 |
+
ơ
|
1539 |
+
ư
|
1540 |
+
ǎ
|
1541 |
+
ǐ
|
1542 |
+
ǒ
|
1543 |
+
ǔ
|
1544 |
+
ǚ
|
1545 |
+
ș
|
1546 |
+
ț
|
1547 |
+
ɑ
|
1548 |
+
ɔ
|
1549 |
+
ɕ
|
1550 |
+
ə
|
1551 |
+
ɛ
|
1552 |
+
ɜ
|
1553 |
+
ɡ
|
1554 |
+
ɣ
|
1555 |
+
ɪ
|
1556 |
+
ɫ
|
1557 |
+
ɴ
|
1558 |
+
ɹ
|
1559 |
+
ɾ
|
1560 |
+
ʃ
|
1561 |
+
ʊ
|
1562 |
+
ʌ
|
1563 |
+
ʒ
|
1564 |
+
ʔ
|
1565 |
+
ʰ
|
1566 |
+
ʷ
|
1567 |
+
ʻ
|
1568 |
+
ʾ
|
1569 |
+
ʿ
|
1570 |
+
ˈ
|
1571 |
+
ː
|
1572 |
+
˙
|
1573 |
+
˜
|
1574 |
+
ˢ
|
1575 |
+
́
|
1576 |
+
̅
|
1577 |
+
Α
|
1578 |
+
Β
|
1579 |
+
Δ
|
1580 |
+
Ε
|
1581 |
+
Θ
|
1582 |
+
Κ
|
1583 |
+
Λ
|
1584 |
+
Μ
|
1585 |
+
Ξ
|
1586 |
+
Π
|
1587 |
+
Σ
|
1588 |
+
Τ
|
1589 |
+
Φ
|
1590 |
+
Χ
|
1591 |
+
Ψ
|
1592 |
+
Ω
|
1593 |
+
ά
|
1594 |
+
έ
|
1595 |
+
ή
|
1596 |
+
ί
|
1597 |
+
α
|
1598 |
+
β
|
1599 |
+
γ
|
1600 |
+
δ
|
1601 |
+
ε
|
1602 |
+
ζ
|
1603 |
+
η
|
1604 |
+
θ
|
1605 |
+
ι
|
1606 |
+
κ
|
1607 |
+
λ
|
1608 |
+
μ
|
1609 |
+
ν
|
1610 |
+
ξ
|
1611 |
+
ο
|
1612 |
+
π
|
1613 |
+
ρ
|
1614 |
+
ς
|
1615 |
+
σ
|
1616 |
+
τ
|
1617 |
+
υ
|
1618 |
+
φ
|
1619 |
+
χ
|
1620 |
+
ψ
|
1621 |
+
ω
|
1622 |
+
ϊ
|
1623 |
+
ό
|
1624 |
+
ύ
|
1625 |
+
ώ
|
1626 |
+
ϕ
|
1627 |
+
ϵ
|
1628 |
+
Ё
|
1629 |
+
А
|
1630 |
+
Б
|
1631 |
+
В
|
1632 |
+
Г
|
1633 |
+
Д
|
1634 |
+
Е
|
1635 |
+
Ж
|
1636 |
+
З
|
1637 |
+
И
|
1638 |
+
Й
|
1639 |
+
К
|
1640 |
+
Л
|
1641 |
+
М
|
1642 |
+
Н
|
1643 |
+
О
|
1644 |
+
П
|
1645 |
+
Р
|
1646 |
+
С
|
1647 |
+
Т
|
1648 |
+
У
|
1649 |
+
Ф
|
1650 |
+
Х
|
1651 |
+
Ц
|
1652 |
+
Ч
|
1653 |
+
Ш
|
1654 |
+
Щ
|
1655 |
+
Ы
|
1656 |
+
Ь
|
1657 |
+
Э
|
1658 |
+
Ю
|
1659 |
+
Я
|
1660 |
+
а
|
1661 |
+
б
|
1662 |
+
в
|
1663 |
+
г
|
1664 |
+
д
|
1665 |
+
е
|
1666 |
+
ж
|
1667 |
+
з
|
1668 |
+
и
|
1669 |
+
й
|
1670 |
+
к
|
1671 |
+
л
|
1672 |
+
м
|
1673 |
+
н
|
1674 |
+
о
|
1675 |
+
п
|
1676 |
+
р
|
1677 |
+
с
|
1678 |
+
т
|
1679 |
+
у
|
1680 |
+
ф
|
1681 |
+
х
|
1682 |
+
ц
|
1683 |
+
ч
|
1684 |
+
ш
|
1685 |
+
щ
|
1686 |
+
ъ
|
1687 |
+
ы
|
1688 |
+
ь
|
1689 |
+
э
|
1690 |
+
ю
|
1691 |
+
я
|
1692 |
+
ё
|
1693 |
+
і
|
1694 |
+
ְ
|
1695 |
+
ִ
|
1696 |
+
ֵ
|
1697 |
+
ֶ
|
1698 |
+
ַ
|
1699 |
+
ָ
|
1700 |
+
ֹ
|
1701 |
+
ּ
|
1702 |
+
־
|
1703 |
+
ׁ
|
1704 |
+
א
|
1705 |
+
ב
|
1706 |
+
ג
|
1707 |
+
ד
|
1708 |
+
ה
|
1709 |
+
ו
|
1710 |
+
ז
|
1711 |
+
ח
|
1712 |
+
ט
|
1713 |
+
י
|
1714 |
+
כ
|
1715 |
+
ל
|
1716 |
+
ם
|
1717 |
+
מ
|
1718 |
+
ן
|
1719 |
+
נ
|
1720 |
+
ס
|
1721 |
+
ע
|
1722 |
+
פ
|
1723 |
+
ק
|
1724 |
+
ר
|
1725 |
+
ש
|
1726 |
+
ת
|
1727 |
+
أ
|
1728 |
+
ب
|
1729 |
+
ة
|
1730 |
+
ت
|
1731 |
+
ج
|
1732 |
+
ح
|
1733 |
+
د
|
1734 |
+
ر
|
1735 |
+
ز
|
1736 |
+
س
|
1737 |
+
ص
|
1738 |
+
ط
|
1739 |
+
ع
|
1740 |
+
ق
|
1741 |
+
ك
|
1742 |
+
ل
|
1743 |
+
م
|
1744 |
+
ن
|
1745 |
+
ه
|
1746 |
+
و
|
1747 |
+
ي
|
1748 |
+
َ
|
1749 |
+
ُ
|
1750 |
+
ِ
|
1751 |
+
ْ
|
1752 |
+
ก
|
1753 |
+
ข
|
1754 |
+
ง
|
1755 |
+
จ
|
1756 |
+
ต
|
1757 |
+
ท
|
1758 |
+
น
|
1759 |
+
ป
|
1760 |
+
ย
|
1761 |
+
ร
|
1762 |
+
ว
|
1763 |
+
ส
|
1764 |
+
ห
|
1765 |
+
อ
|
1766 |
+
ฮ
|
1767 |
+
ั
|
1768 |
+
า
|
1769 |
+
ี
|
1770 |
+
ึ
|
1771 |
+
โ
|
1772 |
+
ใ
|
1773 |
+
ไ
|
1774 |
+
่
|
1775 |
+
้
|
1776 |
+
์
|
1777 |
+
ḍ
|
1778 |
+
Ḥ
|
1779 |
+
ḥ
|
1780 |
+
ṁ
|
1781 |
+
ṃ
|
1782 |
+
ṅ
|
1783 |
+
ṇ
|
1784 |
+
Ṛ
|
1785 |
+
ṛ
|
1786 |
+
Ṣ
|
1787 |
+
ṣ
|
1788 |
+
Ṭ
|
1789 |
+
ṭ
|
1790 |
+
ạ
|
1791 |
+
ả
|
1792 |
+
Ấ
|
1793 |
+
ấ
|
1794 |
+
ầ
|
1795 |
+
ậ
|
1796 |
+
ắ
|
1797 |
+
ằ
|
1798 |
+
ẻ
|
1799 |
+
ẽ
|
1800 |
+
ế
|
1801 |
+
ề
|
1802 |
+
ể
|
1803 |
+
ễ
|
1804 |
+
ệ
|
1805 |
+
ị
|
1806 |
+
ọ
|
1807 |
+
ỏ
|
1808 |
+
ố
|
1809 |
+
ồ
|
1810 |
+
ộ
|
1811 |
+
ớ
|
1812 |
+
ờ
|
1813 |
+
ở
|
1814 |
+
ụ
|
1815 |
+
ủ
|
1816 |
+
ứ
|
1817 |
+
ữ
|
1818 |
+
ἀ
|
1819 |
+
ἁ
|
1820 |
+
Ἀ
|
1821 |
+
ἐ
|
1822 |
+
ἔ
|
1823 |
+
ἰ
|
1824 |
+
ἱ
|
1825 |
+
ὀ
|
1826 |
+
ὁ
|
1827 |
+
ὐ
|
1828 |
+
ὲ
|
1829 |
+
ὸ
|
1830 |
+
���
|
1831 |
+
᾽
|
1832 |
+
ῆ
|
1833 |
+
ῇ
|
1834 |
+
ῶ
|
1835 |
+
|
1836 |
+
‑
|
1837 |
+
‒
|
1838 |
+
–
|
1839 |
+
—
|
1840 |
+
―
|
1841 |
+
‖
|
1842 |
+
†
|
1843 |
+
‡
|
1844 |
+
•
|
1845 |
+
…
|
1846 |
+
‧
|
1847 |
+
|
1848 |
+
′
|
1849 |
+
″
|
1850 |
+
⁄
|
1851 |
+
|
1852 |
+
⁰
|
1853 |
+
⁴
|
1854 |
+
⁵
|
1855 |
+
⁶
|
1856 |
+
⁷
|
1857 |
+
⁸
|
1858 |
+
⁹
|
1859 |
+
₁
|
1860 |
+
₂
|
1861 |
+
₃
|
1862 |
+
€
|
1863 |
+
₱
|
1864 |
+
₹
|
1865 |
+
₽
|
1866 |
+
℃
|
1867 |
+
ℏ
|
1868 |
+
ℓ
|
1869 |
+
№
|
1870 |
+
ℝ
|
1871 |
+
™
|
1872 |
+
⅓
|
1873 |
+
⅔
|
1874 |
+
⅛
|
1875 |
+
→
|
1876 |
+
∂
|
1877 |
+
∈
|
1878 |
+
∑
|
1879 |
+
−
|
1880 |
+
∗
|
1881 |
+
√
|
1882 |
+
∞
|
1883 |
+
∫
|
1884 |
+
≈
|
1885 |
+
≠
|
1886 |
+
≡
|
1887 |
+
≤
|
1888 |
+
≥
|
1889 |
+
⋅
|
1890 |
+
⋯
|
1891 |
+
█
|
1892 |
+
♪
|
1893 |
+
⟨
|
1894 |
+
⟩
|
1895 |
+
、
|
1896 |
+
。
|
1897 |
+
《
|
1898 |
+
》
|
1899 |
+
「
|
1900 |
+
」
|
1901 |
+
【
|
1902 |
+
】
|
1903 |
+
あ
|
1904 |
+
う
|
1905 |
+
え
|
1906 |
+
お
|
1907 |
+
か
|
1908 |
+
が
|
1909 |
+
き
|
1910 |
+
ぎ
|
1911 |
+
く
|
1912 |
+
ぐ
|
1913 |
+
け
|
1914 |
+
げ
|
1915 |
+
こ
|
1916 |
+
ご
|
1917 |
+
さ
|
1918 |
+
し
|
1919 |
+
じ
|
1920 |
+
す
|
1921 |
+
ず
|
1922 |
+
せ
|
1923 |
+
ぜ
|
1924 |
+
そ
|
1925 |
+
ぞ
|
1926 |
+
た
|
1927 |
+
だ
|
1928 |
+
ち
|
1929 |
+
っ
|
1930 |
+
つ
|
1931 |
+
で
|
1932 |
+
と
|
1933 |
+
ど
|
1934 |
+
な
|
1935 |
+
に
|
1936 |
+
ね
|
1937 |
+
の
|
1938 |
+
は
|
1939 |
+
ば
|
1940 |
+
ひ
|
1941 |
+
ぶ
|
1942 |
+
へ
|
1943 |
+
べ
|
1944 |
+
ま
|
1945 |
+
み
|
1946 |
+
む
|
1947 |
+
め
|
1948 |
+
も
|
1949 |
+
ゃ
|
1950 |
+
や
|
1951 |
+
ゆ
|
1952 |
+
ょ
|
1953 |
+
よ
|
1954 |
+
ら
|
1955 |
+
り
|
1956 |
+
る
|
1957 |
+
れ
|
1958 |
+
ろ
|
1959 |
+
わ
|
1960 |
+
を
|
1961 |
+
ん
|
1962 |
+
ァ
|
1963 |
+
ア
|
1964 |
+
ィ
|
1965 |
+
イ
|
1966 |
+
ウ
|
1967 |
+
ェ
|
1968 |
+
エ
|
1969 |
+
オ
|
1970 |
+
カ
|
1971 |
+
ガ
|
1972 |
+
キ
|
1973 |
+
ク
|
1974 |
+
ケ
|
1975 |
+
ゲ
|
1976 |
+
コ
|
1977 |
+
ゴ
|
1978 |
+
サ
|
1979 |
+
ザ
|
1980 |
+
シ
|
1981 |
+
ジ
|
1982 |
+
ス
|
1983 |
+
ズ
|
1984 |
+
セ
|
1985 |
+
ゾ
|
1986 |
+
タ
|
1987 |
+
ダ
|
1988 |
+
チ
|
1989 |
+
ッ
|
1990 |
+
ツ
|
1991 |
+
テ
|
1992 |
+
デ
|
1993 |
+
ト
|
1994 |
+
ド
|
1995 |
+
ナ
|
1996 |
+
ニ
|
1997 |
+
ネ
|
1998 |
+
ノ
|
1999 |
+
バ
|
2000 |
+
パ
|
2001 |
+
ビ
|
2002 |
+
ピ
|
2003 |
+
フ
|
2004 |
+
プ
|
2005 |
+
ヘ
|
2006 |
+
ベ
|
2007 |
+
ペ
|
2008 |
+
ホ
|
2009 |
+
ボ
|
2010 |
+
ポ
|
2011 |
+
マ
|
2012 |
+
ミ
|
2013 |
+
ム
|
2014 |
+
メ
|
2015 |
+
モ
|
2016 |
+
ャ
|
2017 |
+
ヤ
|
2018 |
+
ュ
|
2019 |
+
ユ
|
2020 |
+
ョ
|
2021 |
+
ヨ
|
2022 |
+
ラ
|
2023 |
+
リ
|
2024 |
+
ル
|
2025 |
+
レ
|
2026 |
+
ロ
|
2027 |
+
ワ
|
2028 |
+
ン
|
2029 |
+
・
|
2030 |
+
ー
|
2031 |
+
ㄋ
|
2032 |
+
ㄍ
|
2033 |
+
ㄎ
|
2034 |
+
ㄏ
|
2035 |
+
ㄓ
|
2036 |
+
ㄕ
|
2037 |
+
ㄚ
|
2038 |
+
ㄜ
|
2039 |
+
ㄟ
|
2040 |
+
ㄤ
|
2041 |
+
ㄥ
|
2042 |
+
ㄧ
|
2043 |
+
ㄱ
|
2044 |
+
ㄴ
|
2045 |
+
ㄷ
|
2046 |
+
ㄹ
|
2047 |
+
ㅁ
|
2048 |
+
ㅂ
|
2049 |
+
ㅅ
|
2050 |
+
ㅈ
|
2051 |
+
ㅍ
|
2052 |
+
ㅎ
|
2053 |
+
ㅏ
|
2054 |
+
ㅓ
|
2055 |
+
ㅗ
|
2056 |
+
ㅜ
|
2057 |
+
ㅡ
|
2058 |
+
ㅣ
|
2059 |
+
㗎
|
2060 |
+
가
|
2061 |
+
각
|
2062 |
+
간
|
2063 |
+
갈
|
2064 |
+
감
|
2065 |
+
갑
|
2066 |
+
갓
|
2067 |
+
갔
|
2068 |
+
강
|
2069 |
+
같
|
2070 |
+
개
|
2071 |
+
거
|
2072 |
+
건
|
2073 |
+
걸
|
2074 |
+
겁
|
2075 |
+
것
|
2076 |
+
겉
|
2077 |
+
게
|
2078 |
+
겠
|
2079 |
+
겨
|
2080 |
+
결
|
2081 |
+
겼
|
2082 |
+
경
|
2083 |
+
계
|
2084 |
+
고
|
2085 |
+
곤
|
2086 |
+
골
|
2087 |
+
곱
|
2088 |
+
공
|
2089 |
+
과
|
2090 |
+
관
|
2091 |
+
광
|
2092 |
+
교
|
2093 |
+
구
|
2094 |
+
국
|
2095 |
+
굴
|
2096 |
+
귀
|
2097 |
+
귄
|
2098 |
+
그
|
2099 |
+
근
|
2100 |
+
글
|
2101 |
+
금
|
2102 |
+
기
|
2103 |
+
긴
|
2104 |
+
길
|
2105 |
+
까
|
2106 |
+
깍
|
2107 |
+
깔
|
2108 |
+
깜
|
2109 |
+
깨
|
2110 |
+
께
|
2111 |
+
꼬
|
2112 |
+
꼭
|
2113 |
+
꽃
|
2114 |
+
꾸
|
2115 |
+
꿔
|
2116 |
+
끔
|
2117 |
+
끗
|
2118 |
+
끝
|
2119 |
+
끼
|
2120 |
+
나
|
2121 |
+
난
|
2122 |
+
날
|
2123 |
+
남
|
2124 |
+
납
|
2125 |
+
내
|
2126 |
+
냐
|
2127 |
+
냥
|
2128 |
+
너
|
2129 |
+
넘
|
2130 |
+
넣
|
2131 |
+
네
|
2132 |
+
녁
|
2133 |
+
년
|
2134 |
+
녕
|
2135 |
+
노
|
2136 |
+
녹
|
2137 |
+
놀
|
2138 |
+
누
|
2139 |
+
눈
|
2140 |
+
느
|
2141 |
+
는
|
2142 |
+
늘
|
2143 |
+
니
|
2144 |
+
님
|
2145 |
+
닙
|
2146 |
+
다
|
2147 |
+
닥
|
2148 |
+
단
|
2149 |
+
달
|
2150 |
+
닭
|
2151 |
+
당
|
2152 |
+
대
|
2153 |
+
더
|
2154 |
+
덕
|
2155 |
+
던
|
2156 |
+
덥
|
2157 |
+
데
|
2158 |
+
도
|
2159 |
+
독
|
2160 |
+
동
|
2161 |
+
돼
|
2162 |
+
됐
|
2163 |
+
되
|
2164 |
+
된
|
2165 |
+
될
|
2166 |
+
두
|
2167 |
+
둑
|
2168 |
+
둥
|
2169 |
+
드
|
2170 |
+
들
|
2171 |
+
등
|
2172 |
+
디
|
2173 |
+
따
|
2174 |
+
딱
|
2175 |
+
딸
|
2176 |
+
땅
|
2177 |
+
때
|
2178 |
+
떤
|
2179 |
+
떨
|
2180 |
+
떻
|
2181 |
+
또
|
2182 |
+
똑
|
2183 |
+
뚱
|
2184 |
+
뛰
|
2185 |
+
뜻
|
2186 |
+
띠
|
2187 |
+
라
|
2188 |
+
락
|
2189 |
+
란
|
2190 |
+
람
|
2191 |
+
랍
|
2192 |
+
랑
|
2193 |
+
래
|
2194 |
+
랜
|
2195 |
+
러
|
2196 |
+
런
|
2197 |
+
럼
|
2198 |
+
렇
|
2199 |
+
레
|
2200 |
+
려
|
2201 |
+
력
|
2202 |
+
렵
|
2203 |
+
렸
|
2204 |
+
로
|
2205 |
+
록
|
2206 |
+
롬
|
2207 |
+
루
|
2208 |
+
르
|
2209 |
+
른
|
2210 |
+
를
|
2211 |
+
름
|
2212 |
+
릉
|
2213 |
+
리
|
2214 |
+
릴
|
2215 |
+
림
|
2216 |
+
마
|
2217 |
+
막
|
2218 |
+
만
|
2219 |
+
많
|
2220 |
+
말
|
2221 |
+
맑
|
2222 |
+
맙
|
2223 |
+
맛
|
2224 |
+
매
|
2225 |
+
머
|
2226 |
+
먹
|
2227 |
+
멍
|
2228 |
+
메
|
2229 |
+
면
|
2230 |
+
명
|
2231 |
+
몇
|
2232 |
+
모
|
2233 |
+
목
|
2234 |
+
몸
|
2235 |
+
못
|
2236 |
+
무
|
2237 |
+
문
|
2238 |
+
물
|
2239 |
+
뭐
|
2240 |
+
뭘
|
2241 |
+
미
|
2242 |
+
민
|
2243 |
+
밌
|
2244 |
+
밑
|
2245 |
+
바
|
2246 |
+
박
|
2247 |
+
밖
|
2248 |
+
반
|
2249 |
+
받
|
2250 |
+
발
|
2251 |
+
밤
|
2252 |
+
밥
|
2253 |
+
방
|
2254 |
+
배
|
2255 |
+
백
|
2256 |
+
밸
|
2257 |
+
뱀
|
2258 |
+
버
|
2259 |
+
번
|
2260 |
+
벌
|
2261 |
+
벚
|
2262 |
+
베
|
2263 |
+
벼
|
2264 |
+
벽
|
2265 |
+
별
|
2266 |
+
병
|
2267 |
+
보
|
2268 |
+
복
|
2269 |
+
본
|
2270 |
+
볼
|
2271 |
+
봐
|
2272 |
+
봤
|
2273 |
+
부
|
2274 |
+
분
|
2275 |
+
불
|
2276 |
+
비
|
2277 |
+
빔
|
2278 |
+
빛
|
2279 |
+
빠
|
2280 |
+
빨
|
2281 |
+
뼈
|
2282 |
+
뽀
|
2283 |
+
뿅
|
2284 |
+
쁘
|
2285 |
+
사
|
2286 |
+
산
|
2287 |
+
살
|
2288 |
+
삼
|
2289 |
+
샀
|
2290 |
+
상
|
2291 |
+
새
|
2292 |
+
색
|
2293 |
+
생
|
2294 |
+
서
|
2295 |
+
선
|
2296 |
+
설
|
2297 |
+
섭
|
2298 |
+
섰
|
2299 |
+
성
|
2300 |
+
세
|
2301 |
+
셔
|
2302 |
+
션
|
2303 |
+
셨
|
2304 |
+
소
|
2305 |
+
속
|
2306 |
+
손
|
2307 |
+
송
|
2308 |
+
수
|
2309 |
+
숙
|
2310 |
+
순
|
2311 |
+
술
|
2312 |
+
숫
|
2313 |
+
숭
|
2314 |
+
숲
|
2315 |
+
쉬
|
2316 |
+
쉽
|
2317 |
+
스
|
2318 |
+
슨
|
2319 |
+
습
|
2320 |
+
슷
|
2321 |
+
시
|
2322 |
+
식
|
2323 |
+
신
|
2324 |
+
실
|
2325 |
+
싫
|
2326 |
+
심
|
2327 |
+
십
|
2328 |
+
싶
|
2329 |
+
싸
|
2330 |
+
써
|
2331 |
+
쓰
|
2332 |
+
쓴
|
2333 |
+
씌
|
2334 |
+
씨
|
2335 |
+
씩
|
2336 |
+
씬
|
2337 |
+
아
|
2338 |
+
악
|
2339 |
+
안
|
2340 |
+
않
|
2341 |
+
알
|
2342 |
+
야
|
2343 |
+
약
|
2344 |
+
얀
|
2345 |
+
양
|
2346 |
+
얘
|
2347 |
+
어
|
2348 |
+
언
|
2349 |
+
얼
|
2350 |
+
엄
|
2351 |
+
업
|
2352 |
+
없
|
2353 |
+
었
|
2354 |
+
엉
|
2355 |
+
에
|
2356 |
+
여
|
2357 |
+
역
|
2358 |
+
연
|
2359 |
+
염
|
2360 |
+
엽
|
2361 |
+
영
|
2362 |
+
옆
|
2363 |
+
예
|
2364 |
+
옛
|
2365 |
+
오
|
2366 |
+
온
|
2367 |
+
올
|
2368 |
+
옷
|
2369 |
+
옹
|
2370 |
+
와
|
2371 |
+
왔
|
2372 |
+
왜
|
2373 |
+
요
|
2374 |
+
욕
|
2375 |
+
용
|
2376 |
+
우
|
2377 |
+
운
|
2378 |
+
울
|
2379 |
+
웃
|
2380 |
+
워
|
2381 |
+
원
|
2382 |
+
월
|
2383 |
+
웠
|
2384 |
+
위
|
2385 |
+
윙
|
2386 |
+
유
|
2387 |
+
육
|
2388 |
+
윤
|
2389 |
+
으
|
2390 |
+
은
|
2391 |
+
을
|
2392 |
+
음
|
2393 |
+
응
|
2394 |
+
의
|
2395 |
+
이
|
2396 |
+
익
|
2397 |
+
인
|
2398 |
+
일
|
2399 |
+
읽
|
2400 |
+
임
|
2401 |
+
입
|
2402 |
+
있
|
2403 |
+
자
|
2404 |
+
작
|
2405 |
+
잔
|
2406 |
+
잖
|
2407 |
+
잘
|
2408 |
+
잡
|
2409 |
+
잤
|
2410 |
+
장
|
2411 |
+
재
|
2412 |
+
저
|
2413 |
+
전
|
2414 |
+
점
|
2415 |
+
정
|
2416 |
+
제
|
2417 |
+
져
|
2418 |
+
졌
|
2419 |
+
조
|
2420 |
+
족
|
2421 |
+
좀
|
2422 |
+
종
|
2423 |
+
좋
|
2424 |
+
죠
|
2425 |
+
주
|
2426 |
+
준
|
2427 |
+
줄
|
2428 |
+
중
|
2429 |
+
줘
|
2430 |
+
즈
|
2431 |
+
즐
|
2432 |
+
즘
|
2433 |
+
지
|
2434 |
+
진
|
2435 |
+
집
|
2436 |
+
짜
|
2437 |
+
짝
|
2438 |
+
쩌
|
2439 |
+
쪼
|
2440 |
+
쪽
|
2441 |
+
쫌
|
2442 |
+
쭈
|
2443 |
+
쯔
|
2444 |
+
찌
|
2445 |
+
찍
|
2446 |
+
차
|
2447 |
+
착
|
2448 |
+
찾
|
2449 |
+
책
|
2450 |
+
처
|
2451 |
+
천
|
2452 |
+
철
|
2453 |
+
체
|
2454 |
+
쳐
|
2455 |
+
쳤
|
2456 |
+
초
|
2457 |
+
촌
|
2458 |
+
추
|
2459 |
+
출
|
2460 |
+
춤
|
2461 |
+
춥
|
2462 |
+
춰
|
2463 |
+
치
|
2464 |
+
친
|
2465 |
+
칠
|
2466 |
+
침
|
2467 |
+
칩
|
2468 |
+
칼
|
2469 |
+
커
|
2470 |
+
켓
|
2471 |
+
코
|
2472 |
+
콩
|
2473 |
+
쿠
|
2474 |
+
퀴
|
2475 |
+
크
|
2476 |
+
큰
|
2477 |
+
큽
|
2478 |
+
키
|
2479 |
+
킨
|
2480 |
+
타
|
2481 |
+
태
|
2482 |
+
터
|
2483 |
+
턴
|
2484 |
+
털
|
2485 |
+
테
|
2486 |
+
토
|
2487 |
+
통
|
2488 |
+
투
|
2489 |
+
트
|
2490 |
+
특
|
2491 |
+
튼
|
2492 |
+
틀
|
2493 |
+
티
|
2494 |
+
팀
|
2495 |
+
파
|
2496 |
+
팔
|
2497 |
+
패
|
2498 |
+
페
|
2499 |
+
펜
|
2500 |
+
펭
|
2501 |
+
평
|
2502 |
+
포
|
2503 |
+
폭
|
2504 |
+
표
|
2505 |
+
품
|
2506 |
+
풍
|
2507 |
+
프
|
2508 |
+
플
|
2509 |
+
피
|
2510 |
+
필
|
2511 |
+
하
|
2512 |
+
학
|
2513 |
+
한
|
2514 |
+
할
|
2515 |
+
함
|
2516 |
+
합
|
2517 |
+
항
|
2518 |
+
해
|
2519 |
+
햇
|
2520 |
+
했
|
2521 |
+
행
|
2522 |
+
허
|
2523 |
+
험
|
2524 |
+
형
|
2525 |
+
혜
|
2526 |
+
호
|
2527 |
+
혼
|
2528 |
+
홀
|
2529 |
+
화
|
2530 |
+
회
|
2531 |
+
획
|
2532 |
+
후
|
2533 |
+
휴
|
2534 |
+
흐
|
2535 |
+
흔
|
2536 |
+
희
|
2537 |
+
히
|
2538 |
+
힘
|
2539 |
+
ﷺ
|
2540 |
+
ﷻ
|
2541 |
+
!
|
2542 |
+
,
|
2543 |
+
?
|
2544 |
+
�
|
2545 |
+
𠮶
|
data/librispeech_pc_test_clean_cross_sentence.lst
ADDED
The diff for this file is too large to render.
See raw diff
|
|
model/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from model.cfm import CFM
|
2 |
+
|
3 |
+
from model.backbones.unett import UNetT
|
4 |
+
from model.backbones.dit import DiT
|
5 |
+
from model.backbones.mmdit import MMDiT
|
6 |
+
|
7 |
+
from model.trainer import Trainer
|
model/backbones/README.md
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Backbones quick introduction
|
2 |
+
|
3 |
+
|
4 |
+
### unett.py
|
5 |
+
- flat unet transformer
|
6 |
+
- structure same as in e2-tts & voicebox paper except using rotary pos emb
|
7 |
+
- update: allow possible abs pos emb & convnextv2 blocks for embedded text before concat
|
8 |
+
|
9 |
+
### dit.py
|
10 |
+
- adaln-zero dit
|
11 |
+
- embedded timestep as condition
|
12 |
+
- concatted noised_input + masked_cond + embedded_text, linear proj in
|
13 |
+
- possible abs pos emb & convnextv2 blocks for embedded text before concat
|
14 |
+
- possible long skip connection (first layer to last layer)
|
15 |
+
|
16 |
+
### mmdit.py
|
17 |
+
- sd3 structure
|
18 |
+
- timestep as condition
|
19 |
+
- left stream: text embedded and applied a abs pos emb
|
20 |
+
- right stream: masked_cond & noised_input concatted and with same conv pos emb as unett
|
model/backbones/dit.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
ein notation:
|
3 |
+
b - batch
|
4 |
+
n - sequence
|
5 |
+
nt - text sequence
|
6 |
+
nw - raw wave length
|
7 |
+
d - dimension
|
8 |
+
"""
|
9 |
+
|
10 |
+
from __future__ import annotations
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from torch import nn
|
14 |
+
import torch.nn.functional as F
|
15 |
+
|
16 |
+
from einops import repeat
|
17 |
+
|
18 |
+
from x_transformers.x_transformers import RotaryEmbedding
|
19 |
+
|
20 |
+
from model.modules import (
|
21 |
+
TimestepEmbedding,
|
22 |
+
ConvNeXtV2Block,
|
23 |
+
ConvPositionEmbedding,
|
24 |
+
DiTBlock,
|
25 |
+
AdaLayerNormZero_Final,
|
26 |
+
precompute_freqs_cis, get_pos_embed_indices,
|
27 |
+
)
|
28 |
+
|
29 |
+
|
30 |
+
# Text embedding
|
31 |
+
|
32 |
+
class TextEmbedding(nn.Module):
|
33 |
+
def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2):
|
34 |
+
super().__init__()
|
35 |
+
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
|
36 |
+
|
37 |
+
if conv_layers > 0:
|
38 |
+
self.extra_modeling = True
|
39 |
+
self.precompute_max_pos = 4096 # ~44s of 24khz audio
|
40 |
+
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
|
41 |
+
self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
|
42 |
+
else:
|
43 |
+
self.extra_modeling = False
|
44 |
+
|
45 |
+
def forward(self, text: int['b nt'], seq_len, drop_text = False):
|
46 |
+
batch, text_len = text.shape[0], text.shape[1]
|
47 |
+
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
|
48 |
+
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
49 |
+
text = F.pad(text, (0, seq_len - text_len), value = 0)
|
50 |
+
|
51 |
+
if drop_text: # cfg for text
|
52 |
+
text = torch.zeros_like(text)
|
53 |
+
|
54 |
+
text = self.text_embed(text) # b n -> b n d
|
55 |
+
|
56 |
+
# possible extra modeling
|
57 |
+
if self.extra_modeling:
|
58 |
+
# sinus pos emb
|
59 |
+
batch_start = torch.zeros((batch,), dtype=torch.long)
|
60 |
+
pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
|
61 |
+
text_pos_embed = self.freqs_cis[pos_idx]
|
62 |
+
text = text + text_pos_embed
|
63 |
+
|
64 |
+
# convnextv2 blocks
|
65 |
+
text = self.text_blocks(text)
|
66 |
+
|
67 |
+
return text
|
68 |
+
|
69 |
+
|
70 |
+
# noised input audio and context mixing embedding
|
71 |
+
|
72 |
+
class InputEmbedding(nn.Module):
|
73 |
+
def __init__(self, mel_dim, text_dim, out_dim):
|
74 |
+
super().__init__()
|
75 |
+
self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
|
76 |
+
self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim)
|
77 |
+
|
78 |
+
def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False):
|
79 |
+
if drop_audio_cond: # cfg for cond audio
|
80 |
+
cond = torch.zeros_like(cond)
|
81 |
+
|
82 |
+
x = self.proj(torch.cat((x, cond, text_embed), dim = -1))
|
83 |
+
x = self.conv_pos_embed(x) + x
|
84 |
+
return x
|
85 |
+
|
86 |
+
|
87 |
+
# Transformer backbone using DiT blocks
|
88 |
+
|
89 |
+
class DiT(nn.Module):
|
90 |
+
def __init__(self, *,
|
91 |
+
dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
|
92 |
+
mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0,
|
93 |
+
long_skip_connection = False,
|
94 |
+
):
|
95 |
+
super().__init__()
|
96 |
+
|
97 |
+
self.time_embed = TimestepEmbedding(dim)
|
98 |
+
if text_dim is None:
|
99 |
+
text_dim = mel_dim
|
100 |
+
self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers)
|
101 |
+
self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
|
102 |
+
|
103 |
+
self.rotary_embed = RotaryEmbedding(dim_head)
|
104 |
+
|
105 |
+
self.dim = dim
|
106 |
+
self.depth = depth
|
107 |
+
|
108 |
+
self.transformer_blocks = nn.ModuleList(
|
109 |
+
[
|
110 |
+
DiTBlock(
|
111 |
+
dim = dim,
|
112 |
+
heads = heads,
|
113 |
+
dim_head = dim_head,
|
114 |
+
ff_mult = ff_mult,
|
115 |
+
dropout = dropout
|
116 |
+
)
|
117 |
+
for _ in range(depth)
|
118 |
+
]
|
119 |
+
)
|
120 |
+
self.long_skip_connection = nn.Linear(dim * 2, dim, bias = False) if long_skip_connection else None
|
121 |
+
|
122 |
+
self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
|
123 |
+
self.proj_out = nn.Linear(dim, mel_dim)
|
124 |
+
|
125 |
+
def forward(
|
126 |
+
self,
|
127 |
+
x: float['b n d'], # nosied input audio
|
128 |
+
cond: float['b n d'], # masked cond audio
|
129 |
+
text: int['b nt'], # text
|
130 |
+
time: float['b'] | float[''], # time step
|
131 |
+
drop_audio_cond, # cfg for cond audio
|
132 |
+
drop_text, # cfg for text
|
133 |
+
mask: bool['b n'] | None = None,
|
134 |
+
):
|
135 |
+
batch, seq_len = x.shape[0], x.shape[1]
|
136 |
+
if time.ndim == 0:
|
137 |
+
time = repeat(time, ' -> b', b = batch)
|
138 |
+
|
139 |
+
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
|
140 |
+
t = self.time_embed(time)
|
141 |
+
text_embed = self.text_embed(text, seq_len, drop_text = drop_text)
|
142 |
+
x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond)
|
143 |
+
|
144 |
+
rope = self.rotary_embed.forward_from_seq_len(seq_len)
|
145 |
+
|
146 |
+
if self.long_skip_connection is not None:
|
147 |
+
residual = x
|
148 |
+
|
149 |
+
for block in self.transformer_blocks:
|
150 |
+
x = block(x, t, mask = mask, rope = rope)
|
151 |
+
|
152 |
+
if self.long_skip_connection is not None:
|
153 |
+
x = self.long_skip_connection(torch.cat((x, residual), dim = -1))
|
154 |
+
|
155 |
+
x = self.norm_out(x, t)
|
156 |
+
output = self.proj_out(x)
|
157 |
+
|
158 |
+
return output
|
model/backbones/mmdit.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
ein notation:
|
3 |
+
b - batch
|
4 |
+
n - sequence
|
5 |
+
nt - text sequence
|
6 |
+
nw - raw wave length
|
7 |
+
d - dimension
|
8 |
+
"""
|
9 |
+
|
10 |
+
from __future__ import annotations
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from torch import nn
|
14 |
+
|
15 |
+
from einops import repeat
|
16 |
+
|
17 |
+
from x_transformers.x_transformers import RotaryEmbedding
|
18 |
+
|
19 |
+
from model.modules import (
|
20 |
+
TimestepEmbedding,
|
21 |
+
ConvPositionEmbedding,
|
22 |
+
MMDiTBlock,
|
23 |
+
AdaLayerNormZero_Final,
|
24 |
+
precompute_freqs_cis, get_pos_embed_indices,
|
25 |
+
)
|
26 |
+
|
27 |
+
|
28 |
+
# text embedding
|
29 |
+
|
30 |
+
class TextEmbedding(nn.Module):
|
31 |
+
def __init__(self, out_dim, text_num_embeds):
|
32 |
+
super().__init__()
|
33 |
+
self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim) # will use 0 as filler token
|
34 |
+
|
35 |
+
self.precompute_max_pos = 1024
|
36 |
+
self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
|
37 |
+
|
38 |
+
def forward(self, text: int['b nt'], drop_text = False) -> int['b nt d']:
|
39 |
+
text = text + 1
|
40 |
+
if drop_text:
|
41 |
+
text = torch.zeros_like(text)
|
42 |
+
text = self.text_embed(text)
|
43 |
+
|
44 |
+
# sinus pos emb
|
45 |
+
batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
|
46 |
+
batch_text_len = text.shape[1]
|
47 |
+
pos_idx = get_pos_embed_indices(batch_start, batch_text_len, max_pos=self.precompute_max_pos)
|
48 |
+
text_pos_embed = self.freqs_cis[pos_idx]
|
49 |
+
|
50 |
+
text = text + text_pos_embed
|
51 |
+
|
52 |
+
return text
|
53 |
+
|
54 |
+
|
55 |
+
# noised input & masked cond audio embedding
|
56 |
+
|
57 |
+
class AudioEmbedding(nn.Module):
|
58 |
+
def __init__(self, in_dim, out_dim):
|
59 |
+
super().__init__()
|
60 |
+
self.linear = nn.Linear(2 * in_dim, out_dim)
|
61 |
+
self.conv_pos_embed = ConvPositionEmbedding(out_dim)
|
62 |
+
|
63 |
+
def forward(self, x: float['b n d'], cond: float['b n d'], drop_audio_cond = False):
|
64 |
+
if drop_audio_cond:
|
65 |
+
cond = torch.zeros_like(cond)
|
66 |
+
x = torch.cat((x, cond), dim = -1)
|
67 |
+
x = self.linear(x)
|
68 |
+
x = self.conv_pos_embed(x) + x
|
69 |
+
return x
|
70 |
+
|
71 |
+
|
72 |
+
# Transformer backbone using MM-DiT blocks
|
73 |
+
|
74 |
+
class MMDiT(nn.Module):
|
75 |
+
def __init__(self, *,
|
76 |
+
dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
|
77 |
+
text_num_embeds = 256, mel_dim = 100,
|
78 |
+
):
|
79 |
+
super().__init__()
|
80 |
+
|
81 |
+
self.time_embed = TimestepEmbedding(dim)
|
82 |
+
self.text_embed = TextEmbedding(dim, text_num_embeds)
|
83 |
+
self.audio_embed = AudioEmbedding(mel_dim, dim)
|
84 |
+
|
85 |
+
self.rotary_embed = RotaryEmbedding(dim_head)
|
86 |
+
|
87 |
+
self.dim = dim
|
88 |
+
self.depth = depth
|
89 |
+
|
90 |
+
self.transformer_blocks = nn.ModuleList(
|
91 |
+
[
|
92 |
+
MMDiTBlock(
|
93 |
+
dim = dim,
|
94 |
+
heads = heads,
|
95 |
+
dim_head = dim_head,
|
96 |
+
dropout = dropout,
|
97 |
+
ff_mult = ff_mult,
|
98 |
+
context_pre_only = i == depth - 1,
|
99 |
+
)
|
100 |
+
for i in range(depth)
|
101 |
+
]
|
102 |
+
)
|
103 |
+
self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
|
104 |
+
self.proj_out = nn.Linear(dim, mel_dim)
|
105 |
+
|
106 |
+
def forward(
|
107 |
+
self,
|
108 |
+
x: float['b n d'], # nosied input audio
|
109 |
+
cond: float['b n d'], # masked cond audio
|
110 |
+
text: int['b nt'], # text
|
111 |
+
time: float['b'] | float[''], # time step
|
112 |
+
drop_audio_cond, # cfg for cond audio
|
113 |
+
drop_text, # cfg for text
|
114 |
+
mask: bool['b n'] | None = None,
|
115 |
+
):
|
116 |
+
batch = x.shape[0]
|
117 |
+
if time.ndim == 0:
|
118 |
+
time = repeat(time, ' -> b', b = batch)
|
119 |
+
|
120 |
+
# t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
|
121 |
+
t = self.time_embed(time)
|
122 |
+
c = self.text_embed(text, drop_text = drop_text)
|
123 |
+
x = self.audio_embed(x, cond, drop_audio_cond = drop_audio_cond)
|
124 |
+
|
125 |
+
seq_len = x.shape[1]
|
126 |
+
text_len = text.shape[1]
|
127 |
+
rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
|
128 |
+
rope_text = self.rotary_embed.forward_from_seq_len(text_len)
|
129 |
+
|
130 |
+
for block in self.transformer_blocks:
|
131 |
+
c, x = block(x, c, t, mask = mask, rope = rope_audio, c_rope = rope_text)
|
132 |
+
|
133 |
+
x = self.norm_out(x, t)
|
134 |
+
output = self.proj_out(x)
|
135 |
+
|
136 |
+
return output
|
model/backbones/unett.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
ein notation:
|
3 |
+
b - batch
|
4 |
+
n - sequence
|
5 |
+
nt - text sequence
|
6 |
+
nw - raw wave length
|
7 |
+
d - dimension
|
8 |
+
"""
|
9 |
+
|
10 |
+
from __future__ import annotations
|
11 |
+
from typing import Literal
|
12 |
+
|
13 |
+
import torch
|
14 |
+
from torch import nn
|
15 |
+
import torch.nn.functional as F
|
16 |
+
|
17 |
+
from einops import repeat, pack, unpack
|
18 |
+
|
19 |
+
from x_transformers import RMSNorm
|
20 |
+
from x_transformers.x_transformers import RotaryEmbedding
|
21 |
+
|
22 |
+
from model.modules import (
|
23 |
+
TimestepEmbedding,
|
24 |
+
ConvNeXtV2Block,
|
25 |
+
ConvPositionEmbedding,
|
26 |
+
Attention,
|
27 |
+
AttnProcessor,
|
28 |
+
FeedForward,
|
29 |
+
precompute_freqs_cis, get_pos_embed_indices,
|
30 |
+
)
|
31 |
+
|
32 |
+
|
33 |
+
# Text embedding
|
34 |
+
|
35 |
+
class TextEmbedding(nn.Module):
|
36 |
+
def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2):
|
37 |
+
super().__init__()
|
38 |
+
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
|
39 |
+
|
40 |
+
if conv_layers > 0:
|
41 |
+
self.extra_modeling = True
|
42 |
+
self.precompute_max_pos = 4096 # ~44s of 24khz audio
|
43 |
+
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
|
44 |
+
self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
|
45 |
+
else:
|
46 |
+
self.extra_modeling = False
|
47 |
+
|
48 |
+
def forward(self, text: int['b nt'], seq_len, drop_text = False):
|
49 |
+
batch, text_len = text.shape[0], text.shape[1]
|
50 |
+
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
|
51 |
+
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
52 |
+
text = F.pad(text, (0, seq_len - text_len), value = 0)
|
53 |
+
|
54 |
+
if drop_text: # cfg for text
|
55 |
+
text = torch.zeros_like(text)
|
56 |
+
|
57 |
+
text = self.text_embed(text) # b n -> b n d
|
58 |
+
|
59 |
+
# possible extra modeling
|
60 |
+
if self.extra_modeling:
|
61 |
+
# sinus pos emb
|
62 |
+
batch_start = torch.zeros((batch,), dtype=torch.long)
|
63 |
+
pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
|
64 |
+
text_pos_embed = self.freqs_cis[pos_idx]
|
65 |
+
text = text + text_pos_embed
|
66 |
+
|
67 |
+
# convnextv2 blocks
|
68 |
+
text = self.text_blocks(text)
|
69 |
+
|
70 |
+
return text
|
71 |
+
|
72 |
+
|
73 |
+
# noised input audio and context mixing embedding
|
74 |
+
|
75 |
+
class InputEmbedding(nn.Module):
|
76 |
+
def __init__(self, mel_dim, text_dim, out_dim):
|
77 |
+
super().__init__()
|
78 |
+
self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
|
79 |
+
self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim)
|
80 |
+
|
81 |
+
def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False):
|
82 |
+
if drop_audio_cond: # cfg for cond audio
|
83 |
+
cond = torch.zeros_like(cond)
|
84 |
+
|
85 |
+
x = self.proj(torch.cat((x, cond, text_embed), dim = -1))
|
86 |
+
x = self.conv_pos_embed(x) + x
|
87 |
+
return x
|
88 |
+
|
89 |
+
|
90 |
+
# Flat UNet Transformer backbone
|
91 |
+
|
92 |
+
class UNetT(nn.Module):
|
93 |
+
def __init__(self, *,
|
94 |
+
dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
|
95 |
+
mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0,
|
96 |
+
skip_connect_type: Literal['add', 'concat', 'none'] = 'concat',
|
97 |
+
):
|
98 |
+
super().__init__()
|
99 |
+
assert depth % 2 == 0, "UNet-Transformer's depth should be even."
|
100 |
+
|
101 |
+
self.time_embed = TimestepEmbedding(dim)
|
102 |
+
if text_dim is None:
|
103 |
+
text_dim = mel_dim
|
104 |
+
self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers)
|
105 |
+
self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
|
106 |
+
|
107 |
+
self.rotary_embed = RotaryEmbedding(dim_head)
|
108 |
+
|
109 |
+
# transformer layers & skip connections
|
110 |
+
|
111 |
+
self.dim = dim
|
112 |
+
self.skip_connect_type = skip_connect_type
|
113 |
+
needs_skip_proj = skip_connect_type == 'concat'
|
114 |
+
|
115 |
+
self.depth = depth
|
116 |
+
self.layers = nn.ModuleList([])
|
117 |
+
|
118 |
+
for idx in range(depth):
|
119 |
+
is_later_half = idx >= (depth // 2)
|
120 |
+
|
121 |
+
attn_norm = RMSNorm(dim)
|
122 |
+
attn = Attention(
|
123 |
+
processor = AttnProcessor(),
|
124 |
+
dim = dim,
|
125 |
+
heads = heads,
|
126 |
+
dim_head = dim_head,
|
127 |
+
dropout = dropout,
|
128 |
+
)
|
129 |
+
|
130 |
+
ff_norm = RMSNorm(dim)
|
131 |
+
ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
|
132 |
+
|
133 |
+
skip_proj = nn.Linear(dim * 2, dim, bias = False) if needs_skip_proj and is_later_half else None
|
134 |
+
|
135 |
+
self.layers.append(nn.ModuleList([
|
136 |
+
skip_proj,
|
137 |
+
attn_norm,
|
138 |
+
attn,
|
139 |
+
ff_norm,
|
140 |
+
ff,
|
141 |
+
]))
|
142 |
+
|
143 |
+
self.norm_out = RMSNorm(dim)
|
144 |
+
self.proj_out = nn.Linear(dim, mel_dim)
|
145 |
+
|
146 |
+
def forward(
|
147 |
+
self,
|
148 |
+
x: float['b n d'], # nosied input audio
|
149 |
+
cond: float['b n d'], # masked cond audio
|
150 |
+
text: int['b nt'], # text
|
151 |
+
time: float['b'] | float[''], # time step
|
152 |
+
drop_audio_cond, # cfg for cond audio
|
153 |
+
drop_text, # cfg for text
|
154 |
+
mask: bool['b n'] | None = None,
|
155 |
+
):
|
156 |
+
batch, seq_len = x.shape[0], x.shape[1]
|
157 |
+
if time.ndim == 0:
|
158 |
+
time = repeat(time, ' -> b', b = batch)
|
159 |
+
|
160 |
+
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
|
161 |
+
t = self.time_embed(time)
|
162 |
+
text_embed = self.text_embed(text, seq_len, drop_text = drop_text)
|
163 |
+
x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond)
|
164 |
+
|
165 |
+
# postfix time t to input x, [b n d] -> [b n+1 d]
|
166 |
+
x, ps = pack((t, x), 'b * d')
|
167 |
+
if mask is not None:
|
168 |
+
mask = F.pad(mask, (1, 0), value=1)
|
169 |
+
|
170 |
+
rope = self.rotary_embed.forward_from_seq_len(seq_len + 1)
|
171 |
+
|
172 |
+
# flat unet transformer
|
173 |
+
skip_connect_type = self.skip_connect_type
|
174 |
+
skips = []
|
175 |
+
for idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate(self.layers):
|
176 |
+
layer = idx + 1
|
177 |
+
|
178 |
+
# skip connection logic
|
179 |
+
is_first_half = layer <= (self.depth // 2)
|
180 |
+
is_later_half = not is_first_half
|
181 |
+
|
182 |
+
if is_first_half:
|
183 |
+
skips.append(x)
|
184 |
+
|
185 |
+
if is_later_half:
|
186 |
+
skip = skips.pop()
|
187 |
+
if skip_connect_type == 'concat':
|
188 |
+
x = torch.cat((x, skip), dim = -1)
|
189 |
+
x = maybe_skip_proj(x)
|
190 |
+
elif skip_connect_type == 'add':
|
191 |
+
x = x + skip
|
192 |
+
|
193 |
+
# attention and feedforward blocks
|
194 |
+
x = attn(attn_norm(x), rope = rope, mask = mask) + x
|
195 |
+
x = ff(ff_norm(x)) + x
|
196 |
+
|
197 |
+
assert len(skips) == 0
|
198 |
+
|
199 |
+
_, x = unpack(self.norm_out(x), ps, 'b * d')
|
200 |
+
|
201 |
+
return self.proj_out(x)
|
model/cfm.py
ADDED
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
ein notation:
|
3 |
+
b - batch
|
4 |
+
n - sequence
|
5 |
+
nt - text sequence
|
6 |
+
nw - raw wave length
|
7 |
+
d - dimension
|
8 |
+
"""
|
9 |
+
|
10 |
+
from __future__ import annotations
|
11 |
+
from typing import Callable
|
12 |
+
from random import random
|
13 |
+
|
14 |
+
import torch
|
15 |
+
from torch import nn
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from torch.nn.utils.rnn import pad_sequence
|
18 |
+
|
19 |
+
from torchdiffeq import odeint
|
20 |
+
|
21 |
+
from einops import rearrange
|
22 |
+
|
23 |
+
from model.modules import MelSpec
|
24 |
+
|
25 |
+
from model.utils import (
|
26 |
+
default, exists,
|
27 |
+
list_str_to_idx, list_str_to_tensor,
|
28 |
+
lens_to_mask, mask_from_frac_lengths,
|
29 |
+
)
|
30 |
+
|
31 |
+
|
32 |
+
class CFM(nn.Module):
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
transformer: nn.Module,
|
36 |
+
sigma = 0.,
|
37 |
+
odeint_kwargs: dict = dict(
|
38 |
+
# atol = 1e-5,
|
39 |
+
# rtol = 1e-5,
|
40 |
+
method = 'euler' # 'midpoint'
|
41 |
+
),
|
42 |
+
audio_drop_prob = 0.3,
|
43 |
+
cond_drop_prob = 0.2,
|
44 |
+
num_channels = None,
|
45 |
+
mel_spec_module: nn.Module | None = None,
|
46 |
+
mel_spec_kwargs: dict = dict(),
|
47 |
+
frac_lengths_mask: tuple[float, float] = (0.7, 1.),
|
48 |
+
vocab_char_map: dict[str: int] | None = None
|
49 |
+
):
|
50 |
+
super().__init__()
|
51 |
+
|
52 |
+
self.frac_lengths_mask = frac_lengths_mask
|
53 |
+
|
54 |
+
# mel spec
|
55 |
+
self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs))
|
56 |
+
num_channels = default(num_channels, self.mel_spec.n_mel_channels)
|
57 |
+
self.num_channels = num_channels
|
58 |
+
|
59 |
+
# classifier-free guidance
|
60 |
+
self.audio_drop_prob = audio_drop_prob
|
61 |
+
self.cond_drop_prob = cond_drop_prob
|
62 |
+
|
63 |
+
# transformer
|
64 |
+
self.transformer = transformer
|
65 |
+
dim = transformer.dim
|
66 |
+
self.dim = dim
|
67 |
+
|
68 |
+
# conditional flow related
|
69 |
+
self.sigma = sigma
|
70 |
+
|
71 |
+
# sampling related
|
72 |
+
self.odeint_kwargs = odeint_kwargs
|
73 |
+
|
74 |
+
# vocab map for tokenization
|
75 |
+
self.vocab_char_map = vocab_char_map
|
76 |
+
|
77 |
+
@property
|
78 |
+
def device(self):
|
79 |
+
return next(self.parameters()).device
|
80 |
+
|
81 |
+
@torch.no_grad()
|
82 |
+
def sample(
|
83 |
+
self,
|
84 |
+
cond: float['b n d'] | float['b nw'],
|
85 |
+
text: int['b nt'] | list[str],
|
86 |
+
duration: int | int['b'],
|
87 |
+
*,
|
88 |
+
lens: int['b'] | None = None,
|
89 |
+
steps = 32,
|
90 |
+
cfg_strength = 1.,
|
91 |
+
sway_sampling_coef = None,
|
92 |
+
seed: int | None = None,
|
93 |
+
max_duration = 4096,
|
94 |
+
vocoder: Callable[[float['b d n']], float['b nw']] | None = None,
|
95 |
+
no_ref_audio = False,
|
96 |
+
duplicate_test = False,
|
97 |
+
t_inter = 0.1,
|
98 |
+
):
|
99 |
+
self.eval()
|
100 |
+
|
101 |
+
# raw wave
|
102 |
+
|
103 |
+
if cond.ndim == 2:
|
104 |
+
cond = self.mel_spec(cond)
|
105 |
+
cond = rearrange(cond, 'b d n -> b n d')
|
106 |
+
assert cond.shape[-1] == self.num_channels
|
107 |
+
|
108 |
+
batch, cond_seq_len, device = *cond.shape[:2], cond.device
|
109 |
+
if not exists(lens):
|
110 |
+
lens = torch.full((batch,), cond_seq_len, device = device, dtype = torch.long)
|
111 |
+
|
112 |
+
# text
|
113 |
+
|
114 |
+
if isinstance(text, list):
|
115 |
+
if exists(self.vocab_char_map):
|
116 |
+
text = list_str_to_idx(text, self.vocab_char_map).to(device)
|
117 |
+
else:
|
118 |
+
text = list_str_to_tensor(text).to(device)
|
119 |
+
assert text.shape[0] == batch
|
120 |
+
|
121 |
+
if exists(text):
|
122 |
+
text_lens = (text != -1).sum(dim = -1)
|
123 |
+
lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters
|
124 |
+
|
125 |
+
# duration
|
126 |
+
|
127 |
+
cond_mask = lens_to_mask(lens)
|
128 |
+
|
129 |
+
if isinstance(duration, int):
|
130 |
+
duration = torch.full((batch,), duration, device = device, dtype = torch.long)
|
131 |
+
|
132 |
+
duration = torch.maximum(lens + 1, duration) # just add one token so something is generated
|
133 |
+
duration = duration.clamp(max = max_duration)
|
134 |
+
max_duration = duration.amax()
|
135 |
+
|
136 |
+
# duplicate test corner for inner time step oberservation
|
137 |
+
if duplicate_test:
|
138 |
+
test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2*cond_seq_len), value = 0.)
|
139 |
+
|
140 |
+
cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value = 0.)
|
141 |
+
cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value = False)
|
142 |
+
cond_mask = rearrange(cond_mask, '... -> ... 1')
|
143 |
+
step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) # allow direct control (cut cond audio) with lens passed in
|
144 |
+
|
145 |
+
mask = lens_to_mask(duration)
|
146 |
+
|
147 |
+
# test for no ref audio
|
148 |
+
if no_ref_audio:
|
149 |
+
cond = torch.zeros_like(cond)
|
150 |
+
|
151 |
+
# neural ode
|
152 |
+
|
153 |
+
def fn(t, x):
|
154 |
+
# at each step, conditioning is fixed
|
155 |
+
# step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
|
156 |
+
|
157 |
+
# predict flow
|
158 |
+
pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = False, drop_text = False)
|
159 |
+
if cfg_strength < 1e-5:
|
160 |
+
return pred
|
161 |
+
|
162 |
+
null_pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = True, drop_text = True)
|
163 |
+
return pred + (pred - null_pred) * cfg_strength
|
164 |
+
|
165 |
+
# noise input
|
166 |
+
# to make sure batch inference result is same with different batch size, and for sure single inference
|
167 |
+
# still some difference maybe due to convolutional layers
|
168 |
+
y0 = []
|
169 |
+
for dur in duration:
|
170 |
+
if exists(seed):
|
171 |
+
torch.manual_seed(seed)
|
172 |
+
y0.append(torch.randn(dur, self.num_channels, device = self.device))
|
173 |
+
y0 = pad_sequence(y0, padding_value = 0, batch_first = True)
|
174 |
+
|
175 |
+
t_start = 0
|
176 |
+
|
177 |
+
# duplicate test corner for inner time step oberservation
|
178 |
+
if duplicate_test:
|
179 |
+
t_start = t_inter
|
180 |
+
y0 = (1 - t_start) * y0 + t_start * test_cond
|
181 |
+
steps = int(steps * (1 - t_start))
|
182 |
+
|
183 |
+
t = torch.linspace(t_start, 1, steps, device = self.device)
|
184 |
+
if sway_sampling_coef is not None:
|
185 |
+
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
|
186 |
+
|
187 |
+
trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
|
188 |
+
|
189 |
+
sampled = trajectory[-1]
|
190 |
+
out = sampled
|
191 |
+
out = torch.where(cond_mask, cond, out)
|
192 |
+
|
193 |
+
if exists(vocoder):
|
194 |
+
out = rearrange(out, 'b n d -> b d n')
|
195 |
+
out = vocoder(out)
|
196 |
+
|
197 |
+
return out, trajectory
|
198 |
+
|
199 |
+
def forward(
|
200 |
+
self,
|
201 |
+
inp: float['b n d'] | float['b nw'], # mel or raw wave
|
202 |
+
text: int['b nt'] | list[str],
|
203 |
+
*,
|
204 |
+
lens: int['b'] | None = None,
|
205 |
+
noise_scheduler: str | None = None,
|
206 |
+
):
|
207 |
+
# handle raw wave
|
208 |
+
if inp.ndim == 2:
|
209 |
+
inp = self.mel_spec(inp)
|
210 |
+
inp = rearrange(inp, 'b d n -> b n d')
|
211 |
+
assert inp.shape[-1] == self.num_channels
|
212 |
+
|
213 |
+
batch, seq_len, dtype, device, σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma
|
214 |
+
|
215 |
+
# handle text as string
|
216 |
+
if isinstance(text, list):
|
217 |
+
if exists(self.vocab_char_map):
|
218 |
+
text = list_str_to_idx(text, self.vocab_char_map).to(device)
|
219 |
+
else:
|
220 |
+
text = list_str_to_tensor(text).to(device)
|
221 |
+
assert text.shape[0] == batch
|
222 |
+
|
223 |
+
# lens and mask
|
224 |
+
if not exists(lens):
|
225 |
+
lens = torch.full((batch,), seq_len, device = device)
|
226 |
+
|
227 |
+
mask = lens_to_mask(lens, length = seq_len) # useless here, as collate_fn will pad to max length in batch
|
228 |
+
|
229 |
+
# get a random span to mask out for training conditionally
|
230 |
+
frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask)
|
231 |
+
rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
|
232 |
+
|
233 |
+
if exists(mask):
|
234 |
+
rand_span_mask &= mask
|
235 |
+
|
236 |
+
# mel is x1
|
237 |
+
x1 = inp
|
238 |
+
|
239 |
+
# x0 is gaussian noise
|
240 |
+
x0 = torch.randn_like(x1)
|
241 |
+
|
242 |
+
# time step
|
243 |
+
time = torch.rand((batch,), dtype = dtype, device = self.device)
|
244 |
+
# TODO. noise_scheduler
|
245 |
+
|
246 |
+
# sample xt (φ_t(x) in the paper)
|
247 |
+
t = rearrange(time, 'b -> b 1 1')
|
248 |
+
φ = (1 - t) * x0 + t * x1
|
249 |
+
flow = x1 - x0
|
250 |
+
|
251 |
+
# only predict what is within the random mask span for infilling
|
252 |
+
cond = torch.where(
|
253 |
+
rand_span_mask[..., None],
|
254 |
+
torch.zeros_like(x1), x1
|
255 |
+
)
|
256 |
+
|
257 |
+
# transformer and cfg training with a drop rate
|
258 |
+
drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper
|
259 |
+
if random() < self.cond_drop_prob: # p_uncond in voicebox paper
|
260 |
+
drop_audio_cond = True
|
261 |
+
drop_text = True
|
262 |
+
else:
|
263 |
+
drop_text = False
|
264 |
+
|
265 |
+
# if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here
|
266 |
+
# adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
|
267 |
+
pred = self.transformer(x = φ, cond = cond, text = text, time = time, drop_audio_cond = drop_audio_cond, drop_text = drop_text)
|
268 |
+
|
269 |
+
# flow matching loss
|
270 |
+
loss = F.mse_loss(pred, flow, reduction = 'none')
|
271 |
+
loss = loss[rand_span_mask]
|
272 |
+
|
273 |
+
return loss.mean(), cond, pred
|
model/dataset.py
ADDED
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import random
|
3 |
+
from tqdm import tqdm
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch.utils.data import Dataset, Sampler
|
8 |
+
import torchaudio
|
9 |
+
from datasets import load_dataset, load_from_disk
|
10 |
+
from datasets import Dataset as Dataset_
|
11 |
+
|
12 |
+
from einops import rearrange
|
13 |
+
|
14 |
+
from model.modules import MelSpec
|
15 |
+
|
16 |
+
|
17 |
+
class HFDataset(Dataset):
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
hf_dataset: Dataset,
|
21 |
+
target_sample_rate = 24_000,
|
22 |
+
n_mel_channels = 100,
|
23 |
+
hop_length = 256,
|
24 |
+
):
|
25 |
+
self.data = hf_dataset
|
26 |
+
self.target_sample_rate = target_sample_rate
|
27 |
+
self.hop_length = hop_length
|
28 |
+
self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length)
|
29 |
+
|
30 |
+
def get_frame_len(self, index):
|
31 |
+
row = self.data[index]
|
32 |
+
audio = row['audio']['array']
|
33 |
+
sample_rate = row['audio']['sampling_rate']
|
34 |
+
return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length
|
35 |
+
|
36 |
+
def __len__(self):
|
37 |
+
return len(self.data)
|
38 |
+
|
39 |
+
def __getitem__(self, index):
|
40 |
+
row = self.data[index]
|
41 |
+
audio = row['audio']['array']
|
42 |
+
|
43 |
+
# logger.info(f"Audio shape: {audio.shape}")
|
44 |
+
|
45 |
+
sample_rate = row['audio']['sampling_rate']
|
46 |
+
duration = audio.shape[-1] / sample_rate
|
47 |
+
|
48 |
+
if duration > 30 or duration < 0.3:
|
49 |
+
return self.__getitem__((index + 1) % len(self.data))
|
50 |
+
|
51 |
+
audio_tensor = torch.from_numpy(audio).float()
|
52 |
+
|
53 |
+
if sample_rate != self.target_sample_rate:
|
54 |
+
resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
|
55 |
+
audio_tensor = resampler(audio_tensor)
|
56 |
+
|
57 |
+
audio_tensor = rearrange(audio_tensor, 't -> 1 t')
|
58 |
+
|
59 |
+
mel_spec = self.mel_spectrogram(audio_tensor)
|
60 |
+
|
61 |
+
mel_spec = rearrange(mel_spec, '1 d t -> d t')
|
62 |
+
|
63 |
+
text = row['text']
|
64 |
+
|
65 |
+
return dict(
|
66 |
+
mel_spec = mel_spec,
|
67 |
+
text = text,
|
68 |
+
)
|
69 |
+
|
70 |
+
|
71 |
+
class CustomDataset(Dataset):
|
72 |
+
def __init__(
|
73 |
+
self,
|
74 |
+
custom_dataset: Dataset,
|
75 |
+
durations = None,
|
76 |
+
target_sample_rate = 24_000,
|
77 |
+
hop_length = 256,
|
78 |
+
n_mel_channels = 100,
|
79 |
+
preprocessed_mel = False,
|
80 |
+
):
|
81 |
+
self.data = custom_dataset
|
82 |
+
self.durations = durations
|
83 |
+
self.target_sample_rate = target_sample_rate
|
84 |
+
self.hop_length = hop_length
|
85 |
+
self.preprocessed_mel = preprocessed_mel
|
86 |
+
if not preprocessed_mel:
|
87 |
+
self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, hop_length=hop_length, n_mel_channels=n_mel_channels)
|
88 |
+
|
89 |
+
def get_frame_len(self, index):
|
90 |
+
if self.durations is not None: # Please make sure the separately provided durations are correct, otherwise 99.99% OOM
|
91 |
+
return self.durations[index] * self.target_sample_rate / self.hop_length
|
92 |
+
return self.data[index]["duration"] * self.target_sample_rate / self.hop_length
|
93 |
+
|
94 |
+
def __len__(self):
|
95 |
+
return len(self.data)
|
96 |
+
|
97 |
+
def __getitem__(self, index):
|
98 |
+
row = self.data[index]
|
99 |
+
audio_path = row["audio_path"]
|
100 |
+
text = row["text"]
|
101 |
+
duration = row["duration"]
|
102 |
+
|
103 |
+
if self.preprocessed_mel:
|
104 |
+
mel_spec = torch.tensor(row["mel_spec"])
|
105 |
+
|
106 |
+
else:
|
107 |
+
audio, source_sample_rate = torchaudio.load(audio_path)
|
108 |
+
|
109 |
+
if duration > 30 or duration < 0.3:
|
110 |
+
return self.__getitem__((index + 1) % len(self.data))
|
111 |
+
|
112 |
+
if source_sample_rate != self.target_sample_rate:
|
113 |
+
resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
|
114 |
+
audio = resampler(audio)
|
115 |
+
|
116 |
+
mel_spec = self.mel_spectrogram(audio)
|
117 |
+
mel_spec = rearrange(mel_spec, '1 d t -> d t')
|
118 |
+
|
119 |
+
return dict(
|
120 |
+
mel_spec = mel_spec,
|
121 |
+
text = text,
|
122 |
+
)
|
123 |
+
|
124 |
+
|
125 |
+
# Dynamic Batch Sampler
|
126 |
+
|
127 |
+
class DynamicBatchSampler(Sampler[list[int]]):
|
128 |
+
""" Extension of Sampler that will do the following:
|
129 |
+
1. Change the batch size (essentially number of sequences)
|
130 |
+
in a batch to ensure that the total number of frames are less
|
131 |
+
than a certain threshold.
|
132 |
+
2. Make sure the padding efficiency in the batch is high.
|
133 |
+
"""
|
134 |
+
|
135 |
+
def __init__(self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False):
|
136 |
+
self.sampler = sampler
|
137 |
+
self.frames_threshold = frames_threshold
|
138 |
+
self.max_samples = max_samples
|
139 |
+
|
140 |
+
indices, batches = [], []
|
141 |
+
data_source = self.sampler.data_source
|
142 |
+
|
143 |
+
for idx in tqdm(self.sampler, desc=f"Sorting with sampler... if slow, check whether dataset is provided with duration"):
|
144 |
+
indices.append((idx, data_source.get_frame_len(idx)))
|
145 |
+
indices.sort(key=lambda elem : elem[1])
|
146 |
+
|
147 |
+
batch = []
|
148 |
+
batch_frames = 0
|
149 |
+
for idx, frame_len in tqdm(indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"):
|
150 |
+
if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples):
|
151 |
+
batch.append(idx)
|
152 |
+
batch_frames += frame_len
|
153 |
+
else:
|
154 |
+
if len(batch) > 0:
|
155 |
+
batches.append(batch)
|
156 |
+
if frame_len <= self.frames_threshold:
|
157 |
+
batch = [idx]
|
158 |
+
batch_frames = frame_len
|
159 |
+
else:
|
160 |
+
batch = []
|
161 |
+
batch_frames = 0
|
162 |
+
|
163 |
+
if not drop_last and len(batch) > 0:
|
164 |
+
batches.append(batch)
|
165 |
+
|
166 |
+
del indices
|
167 |
+
|
168 |
+
# if want to have different batches between epochs, may just set a seed and log it in ckpt
|
169 |
+
# cuz during multi-gpu training, although the batch on per gpu not change between epochs, the formed general minibatch is different
|
170 |
+
# e.g. for epoch n, use (random_seed + n)
|
171 |
+
random.seed(random_seed)
|
172 |
+
random.shuffle(batches)
|
173 |
+
|
174 |
+
self.batches = batches
|
175 |
+
|
176 |
+
def __iter__(self):
|
177 |
+
return iter(self.batches)
|
178 |
+
|
179 |
+
def __len__(self):
|
180 |
+
return len(self.batches)
|
181 |
+
|
182 |
+
|
183 |
+
# Load dataset
|
184 |
+
|
185 |
+
def load_dataset(
|
186 |
+
dataset_name: str,
|
187 |
+
tokenizer: str,
|
188 |
+
dataset_type: str = "CustomDataset",
|
189 |
+
audio_type: str = "raw",
|
190 |
+
mel_spec_kwargs: dict = dict()
|
191 |
+
) -> CustomDataset | HFDataset:
|
192 |
+
|
193 |
+
print("Loading dataset ...")
|
194 |
+
|
195 |
+
if dataset_type == "CustomDataset":
|
196 |
+
if audio_type == "raw":
|
197 |
+
try:
|
198 |
+
train_dataset = load_from_disk(f"data/{dataset_name}_{tokenizer}/raw")
|
199 |
+
except:
|
200 |
+
train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/raw.arrow")
|
201 |
+
preprocessed_mel = False
|
202 |
+
elif audio_type == "mel":
|
203 |
+
train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/mel.arrow")
|
204 |
+
preprocessed_mel = True
|
205 |
+
with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'r', encoding='utf-8') as f:
|
206 |
+
data_dict = json.load(f)
|
207 |
+
durations = data_dict["duration"]
|
208 |
+
train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)
|
209 |
+
|
210 |
+
elif dataset_type == "HFDataset":
|
211 |
+
print("Should manually modify the path of huggingface dataset to your need.\n" +
|
212 |
+
"May also the corresponding script cuz different dataset may have different format.")
|
213 |
+
pre, post = dataset_name.split("_")
|
214 |
+
train_dataset = HFDataset(load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir="./data"),)
|
215 |
+
|
216 |
+
return train_dataset
|
217 |
+
|
218 |
+
|
219 |
+
# collation
|
220 |
+
|
221 |
+
def collate_fn(batch):
|
222 |
+
mel_specs = [item['mel_spec'].squeeze(0) for item in batch]
|
223 |
+
mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
|
224 |
+
max_mel_length = mel_lengths.amax()
|
225 |
+
|
226 |
+
padded_mel_specs = []
|
227 |
+
for spec in mel_specs: # TODO. maybe records mask for attention here
|
228 |
+
padding = (0, max_mel_length - spec.size(-1))
|
229 |
+
padded_spec = F.pad(spec, padding, value = 0)
|
230 |
+
padded_mel_specs.append(padded_spec)
|
231 |
+
|
232 |
+
mel_specs = torch.stack(padded_mel_specs)
|
233 |
+
|
234 |
+
text = [item['text'] for item in batch]
|
235 |
+
text_lengths = torch.LongTensor([len(item) for item in text])
|
236 |
+
|
237 |
+
return dict(
|
238 |
+
mel = mel_specs,
|
239 |
+
mel_lengths = mel_lengths,
|
240 |
+
text = text,
|
241 |
+
text_lengths = text_lengths,
|
242 |
+
)
|
model/ecapa_tdnn.py
ADDED
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# just for speaker similarity evaluation, third-party code
|
2 |
+
|
3 |
+
# From https://github.com/microsoft/UniSpeech/blob/main/downstreams/speaker_verification/models/
|
4 |
+
# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
|
5 |
+
|
6 |
+
import os
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
|
12 |
+
''' Res2Conv1d + BatchNorm1d + ReLU
|
13 |
+
'''
|
14 |
+
|
15 |
+
class Res2Conv1dReluBn(nn.Module):
|
16 |
+
'''
|
17 |
+
in_channels == out_channels == channels
|
18 |
+
'''
|
19 |
+
|
20 |
+
def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
|
21 |
+
super().__init__()
|
22 |
+
assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
|
23 |
+
self.scale = scale
|
24 |
+
self.width = channels // scale
|
25 |
+
self.nums = scale if scale == 1 else scale - 1
|
26 |
+
|
27 |
+
self.convs = []
|
28 |
+
self.bns = []
|
29 |
+
for i in range(self.nums):
|
30 |
+
self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias))
|
31 |
+
self.bns.append(nn.BatchNorm1d(self.width))
|
32 |
+
self.convs = nn.ModuleList(self.convs)
|
33 |
+
self.bns = nn.ModuleList(self.bns)
|
34 |
+
|
35 |
+
def forward(self, x):
|
36 |
+
out = []
|
37 |
+
spx = torch.split(x, self.width, 1)
|
38 |
+
for i in range(self.nums):
|
39 |
+
if i == 0:
|
40 |
+
sp = spx[i]
|
41 |
+
else:
|
42 |
+
sp = sp + spx[i]
|
43 |
+
# Order: conv -> relu -> bn
|
44 |
+
sp = self.convs[i](sp)
|
45 |
+
sp = self.bns[i](F.relu(sp))
|
46 |
+
out.append(sp)
|
47 |
+
if self.scale != 1:
|
48 |
+
out.append(spx[self.nums])
|
49 |
+
out = torch.cat(out, dim=1)
|
50 |
+
|
51 |
+
return out
|
52 |
+
|
53 |
+
|
54 |
+
''' Conv1d + BatchNorm1d + ReLU
|
55 |
+
'''
|
56 |
+
|
57 |
+
class Conv1dReluBn(nn.Module):
|
58 |
+
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
|
59 |
+
super().__init__()
|
60 |
+
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
|
61 |
+
self.bn = nn.BatchNorm1d(out_channels)
|
62 |
+
|
63 |
+
def forward(self, x):
|
64 |
+
return self.bn(F.relu(self.conv(x)))
|
65 |
+
|
66 |
+
|
67 |
+
''' The SE connection of 1D case.
|
68 |
+
'''
|
69 |
+
|
70 |
+
class SE_Connect(nn.Module):
|
71 |
+
def __init__(self, channels, se_bottleneck_dim=128):
|
72 |
+
super().__init__()
|
73 |
+
self.linear1 = nn.Linear(channels, se_bottleneck_dim)
|
74 |
+
self.linear2 = nn.Linear(se_bottleneck_dim, channels)
|
75 |
+
|
76 |
+
def forward(self, x):
|
77 |
+
out = x.mean(dim=2)
|
78 |
+
out = F.relu(self.linear1(out))
|
79 |
+
out = torch.sigmoid(self.linear2(out))
|
80 |
+
out = x * out.unsqueeze(2)
|
81 |
+
|
82 |
+
return out
|
83 |
+
|
84 |
+
|
85 |
+
''' SE-Res2Block of the ECAPA-TDNN architecture.
|
86 |
+
'''
|
87 |
+
|
88 |
+
# def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
|
89 |
+
# return nn.Sequential(
|
90 |
+
# Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
|
91 |
+
# Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
|
92 |
+
# Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
|
93 |
+
# SE_Connect(channels)
|
94 |
+
# )
|
95 |
+
|
96 |
+
class SE_Res2Block(nn.Module):
|
97 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
|
98 |
+
super().__init__()
|
99 |
+
self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
100 |
+
self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale)
|
101 |
+
self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
102 |
+
self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
|
103 |
+
|
104 |
+
self.shortcut = None
|
105 |
+
if in_channels != out_channels:
|
106 |
+
self.shortcut = nn.Conv1d(
|
107 |
+
in_channels=in_channels,
|
108 |
+
out_channels=out_channels,
|
109 |
+
kernel_size=1,
|
110 |
+
)
|
111 |
+
|
112 |
+
def forward(self, x):
|
113 |
+
residual = x
|
114 |
+
if self.shortcut:
|
115 |
+
residual = self.shortcut(x)
|
116 |
+
|
117 |
+
x = self.Conv1dReluBn1(x)
|
118 |
+
x = self.Res2Conv1dReluBn(x)
|
119 |
+
x = self.Conv1dReluBn2(x)
|
120 |
+
x = self.SE_Connect(x)
|
121 |
+
|
122 |
+
return x + residual
|
123 |
+
|
124 |
+
|
125 |
+
''' Attentive weighted mean and standard deviation pooling.
|
126 |
+
'''
|
127 |
+
|
128 |
+
class AttentiveStatsPool(nn.Module):
|
129 |
+
def __init__(self, in_dim, attention_channels=128, global_context_att=False):
|
130 |
+
super().__init__()
|
131 |
+
self.global_context_att = global_context_att
|
132 |
+
|
133 |
+
# Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
|
134 |
+
if global_context_att:
|
135 |
+
self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1) # equals W and b in the paper
|
136 |
+
else:
|
137 |
+
self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1) # equals W and b in the paper
|
138 |
+
self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper
|
139 |
+
|
140 |
+
def forward(self, x):
|
141 |
+
|
142 |
+
if self.global_context_att:
|
143 |
+
context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
|
144 |
+
context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
|
145 |
+
x_in = torch.cat((x, context_mean, context_std), dim=1)
|
146 |
+
else:
|
147 |
+
x_in = x
|
148 |
+
|
149 |
+
# DON'T use ReLU here! In experiments, I find ReLU hard to converge.
|
150 |
+
alpha = torch.tanh(self.linear1(x_in))
|
151 |
+
# alpha = F.relu(self.linear1(x_in))
|
152 |
+
alpha = torch.softmax(self.linear2(alpha), dim=2)
|
153 |
+
mean = torch.sum(alpha * x, dim=2)
|
154 |
+
residuals = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
|
155 |
+
std = torch.sqrt(residuals.clamp(min=1e-9))
|
156 |
+
return torch.cat([mean, std], dim=1)
|
157 |
+
|
158 |
+
|
159 |
+
class ECAPA_TDNN(nn.Module):
|
160 |
+
def __init__(self, feat_dim=80, channels=512, emb_dim=192, global_context_att=False,
|
161 |
+
feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
|
162 |
+
super().__init__()
|
163 |
+
|
164 |
+
self.feat_type = feat_type
|
165 |
+
self.feature_selection = feature_selection
|
166 |
+
self.update_extract = update_extract
|
167 |
+
self.sr = sr
|
168 |
+
|
169 |
+
torch.hub._validate_not_a_forked_repo=lambda a,b,c: True
|
170 |
+
try:
|
171 |
+
local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main")
|
172 |
+
self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source='local', config_path=config_path)
|
173 |
+
except:
|
174 |
+
self.feature_extract = torch.hub.load('s3prl/s3prl', feat_type)
|
175 |
+
|
176 |
+
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"):
|
177 |
+
self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
|
178 |
+
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"):
|
179 |
+
self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
|
180 |
+
|
181 |
+
self.feat_num = self.get_feat_num()
|
182 |
+
self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
|
183 |
+
|
184 |
+
if feat_type != 'fbank' and feat_type != 'mfcc':
|
185 |
+
freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer']
|
186 |
+
for name, param in self.feature_extract.named_parameters():
|
187 |
+
for freeze_val in freeze_list:
|
188 |
+
if freeze_val in name:
|
189 |
+
param.requires_grad = False
|
190 |
+
break
|
191 |
+
|
192 |
+
if not self.update_extract:
|
193 |
+
for param in self.feature_extract.parameters():
|
194 |
+
param.requires_grad = False
|
195 |
+
|
196 |
+
self.instance_norm = nn.InstanceNorm1d(feat_dim)
|
197 |
+
# self.channels = [channels] * 4 + [channels * 3]
|
198 |
+
self.channels = [channels] * 4 + [1536]
|
199 |
+
|
200 |
+
self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
|
201 |
+
self.layer2 = SE_Res2Block(self.channels[0], self.channels[1], kernel_size=3, stride=1, padding=2, dilation=2, scale=8, se_bottleneck_dim=128)
|
202 |
+
self.layer3 = SE_Res2Block(self.channels[1], self.channels[2], kernel_size=3, stride=1, padding=3, dilation=3, scale=8, se_bottleneck_dim=128)
|
203 |
+
self.layer4 = SE_Res2Block(self.channels[2], self.channels[3], kernel_size=3, stride=1, padding=4, dilation=4, scale=8, se_bottleneck_dim=128)
|
204 |
+
|
205 |
+
# self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
|
206 |
+
cat_channels = channels * 3
|
207 |
+
self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
|
208 |
+
self.pooling = AttentiveStatsPool(self.channels[-1], attention_channels=128, global_context_att=global_context_att)
|
209 |
+
self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
|
210 |
+
self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
|
211 |
+
|
212 |
+
|
213 |
+
def get_feat_num(self):
|
214 |
+
self.feature_extract.eval()
|
215 |
+
wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
|
216 |
+
with torch.no_grad():
|
217 |
+
features = self.feature_extract(wav)
|
218 |
+
select_feature = features[self.feature_selection]
|
219 |
+
if isinstance(select_feature, (list, tuple)):
|
220 |
+
return len(select_feature)
|
221 |
+
else:
|
222 |
+
return 1
|
223 |
+
|
224 |
+
def get_feat(self, x):
|
225 |
+
if self.update_extract:
|
226 |
+
x = self.feature_extract([sample for sample in x])
|
227 |
+
else:
|
228 |
+
with torch.no_grad():
|
229 |
+
if self.feat_type == 'fbank' or self.feat_type == 'mfcc':
|
230 |
+
x = self.feature_extract(x) + 1e-6 # B x feat_dim x time_len
|
231 |
+
else:
|
232 |
+
x = self.feature_extract([sample for sample in x])
|
233 |
+
|
234 |
+
if self.feat_type == 'fbank':
|
235 |
+
x = x.log()
|
236 |
+
|
237 |
+
if self.feat_type != "fbank" and self.feat_type != "mfcc":
|
238 |
+
x = x[self.feature_selection]
|
239 |
+
if isinstance(x, (list, tuple)):
|
240 |
+
x = torch.stack(x, dim=0)
|
241 |
+
else:
|
242 |
+
x = x.unsqueeze(0)
|
243 |
+
norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
244 |
+
x = (norm_weights * x).sum(dim=0)
|
245 |
+
x = torch.transpose(x, 1, 2) + 1e-6
|
246 |
+
|
247 |
+
x = self.instance_norm(x)
|
248 |
+
return x
|
249 |
+
|
250 |
+
def forward(self, x):
|
251 |
+
x = self.get_feat(x)
|
252 |
+
|
253 |
+
out1 = self.layer1(x)
|
254 |
+
out2 = self.layer2(out1)
|
255 |
+
out3 = self.layer3(out2)
|
256 |
+
out4 = self.layer4(out3)
|
257 |
+
|
258 |
+
out = torch.cat([out2, out3, out4], dim=1)
|
259 |
+
out = F.relu(self.conv(out))
|
260 |
+
out = self.bn(self.pooling(out))
|
261 |
+
out = self.linear(out)
|
262 |
+
|
263 |
+
return out
|
264 |
+
|
265 |
+
|
266 |
+
def ECAPA_TDNN_SMALL(feat_dim, emb_dim=256, feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
|
267 |
+
return ECAPA_TDNN(feat_dim=feat_dim, channels=512, emb_dim=emb_dim,
|
268 |
+
feat_type=feat_type, sr=sr, feature_selection=feature_selection, update_extract=update_extract, config_path=config_path)
|
model/modules.py
ADDED
@@ -0,0 +1,575 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
ein notation:
|
3 |
+
b - batch
|
4 |
+
n - sequence
|
5 |
+
nt - text sequence
|
6 |
+
nw - raw wave length
|
7 |
+
d - dimension
|
8 |
+
"""
|
9 |
+
|
10 |
+
from __future__ import annotations
|
11 |
+
from typing import Optional
|
12 |
+
import math
|
13 |
+
|
14 |
+
import torch
|
15 |
+
from torch import nn
|
16 |
+
import torch.nn.functional as F
|
17 |
+
import torchaudio
|
18 |
+
|
19 |
+
from einops import rearrange
|
20 |
+
from x_transformers.x_transformers import apply_rotary_pos_emb
|
21 |
+
|
22 |
+
|
23 |
+
# raw wav to mel spec
|
24 |
+
|
25 |
+
class MelSpec(nn.Module):
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
filter_length = 1024,
|
29 |
+
hop_length = 256,
|
30 |
+
win_length = 1024,
|
31 |
+
n_mel_channels = 100,
|
32 |
+
target_sample_rate = 24_000,
|
33 |
+
normalize = False,
|
34 |
+
power = 1,
|
35 |
+
norm = None,
|
36 |
+
center = True,
|
37 |
+
):
|
38 |
+
super().__init__()
|
39 |
+
self.n_mel_channels = n_mel_channels
|
40 |
+
|
41 |
+
self.mel_stft = torchaudio.transforms.MelSpectrogram(
|
42 |
+
sample_rate = target_sample_rate,
|
43 |
+
n_fft = filter_length,
|
44 |
+
win_length = win_length,
|
45 |
+
hop_length = hop_length,
|
46 |
+
n_mels = n_mel_channels,
|
47 |
+
power = power,
|
48 |
+
center = center,
|
49 |
+
normalized = normalize,
|
50 |
+
norm = norm,
|
51 |
+
)
|
52 |
+
|
53 |
+
self.register_buffer('dummy', torch.tensor(0), persistent = False)
|
54 |
+
|
55 |
+
def forward(self, inp):
|
56 |
+
if len(inp.shape) == 3:
|
57 |
+
inp = rearrange(inp, 'b 1 nw -> b nw')
|
58 |
+
|
59 |
+
assert len(inp.shape) == 2
|
60 |
+
|
61 |
+
if self.dummy.device != inp.device:
|
62 |
+
self.to(inp.device)
|
63 |
+
|
64 |
+
mel = self.mel_stft(inp)
|
65 |
+
mel = mel.clamp(min = 1e-5).log()
|
66 |
+
return mel
|
67 |
+
|
68 |
+
|
69 |
+
# sinusoidal position embedding
|
70 |
+
|
71 |
+
class SinusPositionEmbedding(nn.Module):
|
72 |
+
def __init__(self, dim):
|
73 |
+
super().__init__()
|
74 |
+
self.dim = dim
|
75 |
+
|
76 |
+
def forward(self, x, scale=1000):
|
77 |
+
device = x.device
|
78 |
+
half_dim = self.dim // 2
|
79 |
+
emb = math.log(10000) / (half_dim - 1)
|
80 |
+
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
|
81 |
+
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
|
82 |
+
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
83 |
+
return emb
|
84 |
+
|
85 |
+
|
86 |
+
# convolutional position embedding
|
87 |
+
|
88 |
+
class ConvPositionEmbedding(nn.Module):
|
89 |
+
def __init__(self, dim, kernel_size = 31, groups = 16):
|
90 |
+
super().__init__()
|
91 |
+
assert kernel_size % 2 != 0
|
92 |
+
self.conv1d = nn.Sequential(
|
93 |
+
nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
|
94 |
+
nn.Mish(),
|
95 |
+
nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
|
96 |
+
nn.Mish(),
|
97 |
+
)
|
98 |
+
|
99 |
+
def forward(self, x: float['b n d'], mask: bool['b n'] | None = None):
|
100 |
+
if mask is not None:
|
101 |
+
mask = mask[..., None]
|
102 |
+
x = x.masked_fill(~mask, 0.)
|
103 |
+
|
104 |
+
x = rearrange(x, 'b n d -> b d n')
|
105 |
+
x = self.conv1d(x)
|
106 |
+
out = rearrange(x, 'b d n -> b n d')
|
107 |
+
|
108 |
+
if mask is not None:
|
109 |
+
out = out.masked_fill(~mask, 0.)
|
110 |
+
|
111 |
+
return out
|
112 |
+
|
113 |
+
|
114 |
+
# rotary positional embedding related
|
115 |
+
|
116 |
+
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.):
|
117 |
+
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
118 |
+
# has some connection to NTK literature
|
119 |
+
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
120 |
+
# https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
|
121 |
+
theta *= theta_rescale_factor ** (dim / (dim - 2))
|
122 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
123 |
+
t = torch.arange(end, device=freqs.device) # type: ignore
|
124 |
+
freqs = torch.outer(t, freqs).float() # type: ignore
|
125 |
+
freqs_cos = torch.cos(freqs) # real part
|
126 |
+
freqs_sin = torch.sin(freqs) # imaginary part
|
127 |
+
return torch.cat([freqs_cos, freqs_sin], dim=-1)
|
128 |
+
|
129 |
+
def get_pos_embed_indices(start, length, max_pos, scale=1.):
|
130 |
+
# length = length if isinstance(length, int) else length.max()
|
131 |
+
scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
|
132 |
+
pos = start.unsqueeze(1) + (
|
133 |
+
torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) *
|
134 |
+
scale.unsqueeze(1)).long()
|
135 |
+
# avoid extra long error.
|
136 |
+
pos = torch.where(pos < max_pos, pos, max_pos - 1)
|
137 |
+
return pos
|
138 |
+
|
139 |
+
|
140 |
+
# Global Response Normalization layer (Instance Normalization ?)
|
141 |
+
|
142 |
+
class GRN(nn.Module):
|
143 |
+
def __init__(self, dim):
|
144 |
+
super().__init__()
|
145 |
+
self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
|
146 |
+
self.beta = nn.Parameter(torch.zeros(1, 1, dim))
|
147 |
+
|
148 |
+
def forward(self, x):
|
149 |
+
Gx = torch.norm(x, p=2, dim=1, keepdim=True)
|
150 |
+
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
|
151 |
+
return self.gamma * (x * Nx) + self.beta + x
|
152 |
+
|
153 |
+
|
154 |
+
# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
|
155 |
+
# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
|
156 |
+
|
157 |
+
class ConvNeXtV2Block(nn.Module):
|
158 |
+
def __init__(
|
159 |
+
self,
|
160 |
+
dim: int,
|
161 |
+
intermediate_dim: int,
|
162 |
+
dilation: int = 1,
|
163 |
+
):
|
164 |
+
super().__init__()
|
165 |
+
padding = (dilation * (7 - 1)) // 2
|
166 |
+
self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation) # depthwise conv
|
167 |
+
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
168 |
+
self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
|
169 |
+
self.act = nn.GELU()
|
170 |
+
self.grn = GRN(intermediate_dim)
|
171 |
+
self.pwconv2 = nn.Linear(intermediate_dim, dim)
|
172 |
+
|
173 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
174 |
+
residual = x
|
175 |
+
x = x.transpose(1, 2) # b n d -> b d n
|
176 |
+
x = self.dwconv(x)
|
177 |
+
x = x.transpose(1, 2) # b d n -> b n d
|
178 |
+
x = self.norm(x)
|
179 |
+
x = self.pwconv1(x)
|
180 |
+
x = self.act(x)
|
181 |
+
x = self.grn(x)
|
182 |
+
x = self.pwconv2(x)
|
183 |
+
return residual + x
|
184 |
+
|
185 |
+
|
186 |
+
# AdaLayerNormZero
|
187 |
+
# return with modulated x for attn input, and params for later mlp modulation
|
188 |
+
|
189 |
+
class AdaLayerNormZero(nn.Module):
|
190 |
+
def __init__(self, dim):
|
191 |
+
super().__init__()
|
192 |
+
|
193 |
+
self.silu = nn.SiLU()
|
194 |
+
self.linear = nn.Linear(dim, dim * 6)
|
195 |
+
|
196 |
+
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
197 |
+
|
198 |
+
def forward(self, x, emb = None):
|
199 |
+
emb = self.linear(self.silu(emb))
|
200 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
|
201 |
+
|
202 |
+
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
203 |
+
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
204 |
+
|
205 |
+
|
206 |
+
# AdaLayerNormZero for final layer
|
207 |
+
# return only with modulated x for attn input, cuz no more mlp modulation
|
208 |
+
|
209 |
+
class AdaLayerNormZero_Final(nn.Module):
|
210 |
+
def __init__(self, dim):
|
211 |
+
super().__init__()
|
212 |
+
|
213 |
+
self.silu = nn.SiLU()
|
214 |
+
self.linear = nn.Linear(dim, dim * 2)
|
215 |
+
|
216 |
+
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
217 |
+
|
218 |
+
def forward(self, x, emb):
|
219 |
+
emb = self.linear(self.silu(emb))
|
220 |
+
scale, shift = torch.chunk(emb, 2, dim=1)
|
221 |
+
|
222 |
+
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
|
223 |
+
return x
|
224 |
+
|
225 |
+
|
226 |
+
# FeedForward
|
227 |
+
|
228 |
+
class FeedForward(nn.Module):
|
229 |
+
def __init__(self, dim, dim_out = None, mult = 4, dropout = 0., approximate: str = 'none'):
|
230 |
+
super().__init__()
|
231 |
+
inner_dim = int(dim * mult)
|
232 |
+
dim_out = dim_out if dim_out is not None else dim
|
233 |
+
|
234 |
+
activation = nn.GELU(approximate=approximate)
|
235 |
+
project_in = nn.Sequential(
|
236 |
+
nn.Linear(dim, inner_dim),
|
237 |
+
activation
|
238 |
+
)
|
239 |
+
self.ff = nn.Sequential(
|
240 |
+
project_in,
|
241 |
+
nn.Dropout(dropout),
|
242 |
+
nn.Linear(inner_dim, dim_out)
|
243 |
+
)
|
244 |
+
|
245 |
+
def forward(self, x):
|
246 |
+
return self.ff(x)
|
247 |
+
|
248 |
+
|
249 |
+
# Attention with possible joint part
|
250 |
+
# modified from diffusers/src/diffusers/models/attention_processor.py
|
251 |
+
|
252 |
+
class Attention(nn.Module):
|
253 |
+
def __init__(
|
254 |
+
self,
|
255 |
+
processor: JointAttnProcessor | AttnProcessor,
|
256 |
+
dim: int,
|
257 |
+
heads: int = 8,
|
258 |
+
dim_head: int = 64,
|
259 |
+
dropout: float = 0.0,
|
260 |
+
context_dim: Optional[int] = None, # if not None -> joint attention
|
261 |
+
context_pre_only = None,
|
262 |
+
):
|
263 |
+
super().__init__()
|
264 |
+
|
265 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
266 |
+
raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
267 |
+
|
268 |
+
self.processor = processor
|
269 |
+
|
270 |
+
self.dim = dim
|
271 |
+
self.heads = heads
|
272 |
+
self.inner_dim = dim_head * heads
|
273 |
+
self.dropout = dropout
|
274 |
+
|
275 |
+
self.context_dim = context_dim
|
276 |
+
self.context_pre_only = context_pre_only
|
277 |
+
|
278 |
+
self.to_q = nn.Linear(dim, self.inner_dim)
|
279 |
+
self.to_k = nn.Linear(dim, self.inner_dim)
|
280 |
+
self.to_v = nn.Linear(dim, self.inner_dim)
|
281 |
+
|
282 |
+
if self.context_dim is not None:
|
283 |
+
self.to_k_c = nn.Linear(context_dim, self.inner_dim)
|
284 |
+
self.to_v_c = nn.Linear(context_dim, self.inner_dim)
|
285 |
+
if self.context_pre_only is not None:
|
286 |
+
self.to_q_c = nn.Linear(context_dim, self.inner_dim)
|
287 |
+
|
288 |
+
self.to_out = nn.ModuleList([])
|
289 |
+
self.to_out.append(nn.Linear(self.inner_dim, dim))
|
290 |
+
self.to_out.append(nn.Dropout(dropout))
|
291 |
+
|
292 |
+
if self.context_pre_only is not None and not self.context_pre_only:
|
293 |
+
self.to_out_c = nn.Linear(self.inner_dim, dim)
|
294 |
+
|
295 |
+
def forward(
|
296 |
+
self,
|
297 |
+
x: float['b n d'], # noised input x
|
298 |
+
c: float['b n d'] = None, # context c
|
299 |
+
mask: bool['b n'] | None = None,
|
300 |
+
rope = None, # rotary position embedding for x
|
301 |
+
c_rope = None, # rotary position embedding for c
|
302 |
+
) -> torch.Tensor:
|
303 |
+
if c is not None:
|
304 |
+
return self.processor(self, x, c = c, mask = mask, rope = rope, c_rope = c_rope)
|
305 |
+
else:
|
306 |
+
return self.processor(self, x, mask = mask, rope = rope)
|
307 |
+
|
308 |
+
|
309 |
+
# Attention processor
|
310 |
+
|
311 |
+
class AttnProcessor:
|
312 |
+
def __init__(self):
|
313 |
+
pass
|
314 |
+
|
315 |
+
def __call__(
|
316 |
+
self,
|
317 |
+
attn: Attention,
|
318 |
+
x: float['b n d'], # noised input x
|
319 |
+
mask: bool['b n'] | None = None,
|
320 |
+
rope = None, # rotary position embedding
|
321 |
+
) -> torch.FloatTensor:
|
322 |
+
|
323 |
+
batch_size = x.shape[0]
|
324 |
+
|
325 |
+
# `sample` projections.
|
326 |
+
query = attn.to_q(x)
|
327 |
+
key = attn.to_k(x)
|
328 |
+
value = attn.to_v(x)
|
329 |
+
|
330 |
+
# apply rotary position embedding
|
331 |
+
if rope is not None:
|
332 |
+
freqs, xpos_scale = rope
|
333 |
+
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
|
334 |
+
|
335 |
+
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
336 |
+
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
337 |
+
|
338 |
+
# attention
|
339 |
+
inner_dim = key.shape[-1]
|
340 |
+
head_dim = inner_dim // attn.heads
|
341 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
342 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
343 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
344 |
+
|
345 |
+
# mask. e.g. inference got a batch with different target durations, mask out the padding
|
346 |
+
if mask is not None:
|
347 |
+
attn_mask = mask
|
348 |
+
attn_mask = rearrange(attn_mask, 'b n -> b 1 1 n')
|
349 |
+
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
|
350 |
+
else:
|
351 |
+
attn_mask = None
|
352 |
+
|
353 |
+
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
|
354 |
+
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
355 |
+
x = x.to(query.dtype)
|
356 |
+
|
357 |
+
# linear proj
|
358 |
+
x = attn.to_out[0](x)
|
359 |
+
# dropout
|
360 |
+
x = attn.to_out[1](x)
|
361 |
+
|
362 |
+
if mask is not None:
|
363 |
+
mask = rearrange(mask, 'b n -> b n 1')
|
364 |
+
x = x.masked_fill(~mask, 0.)
|
365 |
+
|
366 |
+
return x
|
367 |
+
|
368 |
+
|
369 |
+
# Joint Attention processor for MM-DiT
|
370 |
+
# modified from diffusers/src/diffusers/models/attention_processor.py
|
371 |
+
|
372 |
+
class JointAttnProcessor:
|
373 |
+
def __init__(self):
|
374 |
+
pass
|
375 |
+
|
376 |
+
def __call__(
|
377 |
+
self,
|
378 |
+
attn: Attention,
|
379 |
+
x: float['b n d'], # noised input x
|
380 |
+
c: float['b nt d'] = None, # context c, here text
|
381 |
+
mask: bool['b n'] | None = None,
|
382 |
+
rope = None, # rotary position embedding for x
|
383 |
+
c_rope = None, # rotary position embedding for c
|
384 |
+
) -> torch.FloatTensor:
|
385 |
+
residual = x
|
386 |
+
|
387 |
+
batch_size = c.shape[0]
|
388 |
+
|
389 |
+
# `sample` projections.
|
390 |
+
query = attn.to_q(x)
|
391 |
+
key = attn.to_k(x)
|
392 |
+
value = attn.to_v(x)
|
393 |
+
|
394 |
+
# `context` projections.
|
395 |
+
c_query = attn.to_q_c(c)
|
396 |
+
c_key = attn.to_k_c(c)
|
397 |
+
c_value = attn.to_v_c(c)
|
398 |
+
|
399 |
+
# apply rope for context and noised input independently
|
400 |
+
if rope is not None:
|
401 |
+
freqs, xpos_scale = rope
|
402 |
+
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
|
403 |
+
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
404 |
+
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
405 |
+
if c_rope is not None:
|
406 |
+
freqs, xpos_scale = c_rope
|
407 |
+
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
|
408 |
+
c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
|
409 |
+
c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
|
410 |
+
|
411 |
+
# attention
|
412 |
+
query = torch.cat([query, c_query], dim=1)
|
413 |
+
key = torch.cat([key, c_key], dim=1)
|
414 |
+
value = torch.cat([value, c_value], dim=1)
|
415 |
+
|
416 |
+
inner_dim = key.shape[-1]
|
417 |
+
head_dim = inner_dim // attn.heads
|
418 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
419 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
420 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
421 |
+
|
422 |
+
# mask. e.g. inference got a batch with different target durations, mask out the padding
|
423 |
+
if mask is not None:
|
424 |
+
attn_mask = F.pad(mask, (0, c.shape[1]), value = True) # no mask for c (text)
|
425 |
+
attn_mask = rearrange(attn_mask, 'b n -> b 1 1 n')
|
426 |
+
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
|
427 |
+
else:
|
428 |
+
attn_mask = None
|
429 |
+
|
430 |
+
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
|
431 |
+
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
432 |
+
x = x.to(query.dtype)
|
433 |
+
|
434 |
+
# Split the attention outputs.
|
435 |
+
x, c = (
|
436 |
+
x[:, :residual.shape[1]],
|
437 |
+
x[:, residual.shape[1]:],
|
438 |
+
)
|
439 |
+
|
440 |
+
# linear proj
|
441 |
+
x = attn.to_out[0](x)
|
442 |
+
# dropout
|
443 |
+
x = attn.to_out[1](x)
|
444 |
+
if not attn.context_pre_only:
|
445 |
+
c = attn.to_out_c(c)
|
446 |
+
|
447 |
+
if mask is not None:
|
448 |
+
mask = rearrange(mask, 'b n -> b n 1')
|
449 |
+
x = x.masked_fill(~mask, 0.)
|
450 |
+
# c = c.masked_fill(~mask, 0.) # no mask for c (text)
|
451 |
+
|
452 |
+
return x, c
|
453 |
+
|
454 |
+
|
455 |
+
# DiT Block
|
456 |
+
|
457 |
+
class DiTBlock(nn.Module):
|
458 |
+
|
459 |
+
def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1):
|
460 |
+
super().__init__()
|
461 |
+
|
462 |
+
self.attn_norm = AdaLayerNormZero(dim)
|
463 |
+
self.attn = Attention(
|
464 |
+
processor = AttnProcessor(),
|
465 |
+
dim = dim,
|
466 |
+
heads = heads,
|
467 |
+
dim_head = dim_head,
|
468 |
+
dropout = dropout,
|
469 |
+
)
|
470 |
+
|
471 |
+
self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
472 |
+
self.ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
|
473 |
+
|
474 |
+
def forward(self, x, t, mask = None, rope = None): # x: noised input, t: time embedding
|
475 |
+
# pre-norm & modulation for attention input
|
476 |
+
norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
|
477 |
+
|
478 |
+
# attention
|
479 |
+
attn_output = self.attn(x=norm, mask=mask, rope=rope)
|
480 |
+
|
481 |
+
# process attention output for input x
|
482 |
+
x = x + gate_msa.unsqueeze(1) * attn_output
|
483 |
+
|
484 |
+
norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
485 |
+
ff_output = self.ff(norm)
|
486 |
+
x = x + gate_mlp.unsqueeze(1) * ff_output
|
487 |
+
|
488 |
+
return x
|
489 |
+
|
490 |
+
|
491 |
+
# MMDiT Block https://arxiv.org/abs/2403.03206
|
492 |
+
|
493 |
+
class MMDiTBlock(nn.Module):
|
494 |
+
r"""
|
495 |
+
modified from diffusers/src/diffusers/models/attention.py
|
496 |
+
|
497 |
+
notes.
|
498 |
+
_c: context related. text, cond, etc. (left part in sd3 fig2.b)
|
499 |
+
_x: noised input related. (right part)
|
500 |
+
context_pre_only: last layer only do prenorm + modulation cuz no more ffn
|
501 |
+
"""
|
502 |
+
|
503 |
+
def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1, context_pre_only = False):
|
504 |
+
super().__init__()
|
505 |
+
|
506 |
+
self.context_pre_only = context_pre_only
|
507 |
+
|
508 |
+
self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
|
509 |
+
self.attn_norm_x = AdaLayerNormZero(dim)
|
510 |
+
self.attn = Attention(
|
511 |
+
processor = JointAttnProcessor(),
|
512 |
+
dim = dim,
|
513 |
+
heads = heads,
|
514 |
+
dim_head = dim_head,
|
515 |
+
dropout = dropout,
|
516 |
+
context_dim = dim,
|
517 |
+
context_pre_only = context_pre_only,
|
518 |
+
)
|
519 |
+
|
520 |
+
if not context_pre_only:
|
521 |
+
self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
522 |
+
self.ff_c = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
|
523 |
+
else:
|
524 |
+
self.ff_norm_c = None
|
525 |
+
self.ff_c = None
|
526 |
+
self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
527 |
+
self.ff_x = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
|
528 |
+
|
529 |
+
def forward(self, x, c, t, mask = None, rope = None, c_rope = None): # x: noised input, c: context, t: time embedding
|
530 |
+
# pre-norm & modulation for attention input
|
531 |
+
if self.context_pre_only:
|
532 |
+
norm_c = self.attn_norm_c(c, t)
|
533 |
+
else:
|
534 |
+
norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
|
535 |
+
norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)
|
536 |
+
|
537 |
+
# attention
|
538 |
+
x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)
|
539 |
+
|
540 |
+
# process attention output for context c
|
541 |
+
if self.context_pre_only:
|
542 |
+
c = None
|
543 |
+
else: # if not last layer
|
544 |
+
c = c + c_gate_msa.unsqueeze(1) * c_attn_output
|
545 |
+
|
546 |
+
norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
547 |
+
c_ff_output = self.ff_c(norm_c)
|
548 |
+
c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
|
549 |
+
|
550 |
+
# process attention output for input x
|
551 |
+
x = x + x_gate_msa.unsqueeze(1) * x_attn_output
|
552 |
+
|
553 |
+
norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
|
554 |
+
x_ff_output = self.ff_x(norm_x)
|
555 |
+
x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
|
556 |
+
|
557 |
+
return c, x
|
558 |
+
|
559 |
+
|
560 |
+
# time step conditioning embedding
|
561 |
+
|
562 |
+
class TimestepEmbedding(nn.Module):
|
563 |
+
def __init__(self, dim, freq_embed_dim=256):
|
564 |
+
super().__init__()
|
565 |
+
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
|
566 |
+
self.time_mlp = nn.Sequential(
|
567 |
+
nn.Linear(freq_embed_dim, dim),
|
568 |
+
nn.SiLU(),
|
569 |
+
nn.Linear(dim, dim)
|
570 |
+
)
|
571 |
+
|
572 |
+
def forward(self, timestep: float['b']):
|
573 |
+
time_hidden = self.time_embed(timestep)
|
574 |
+
time = self.time_mlp(time_hidden) # b d
|
575 |
+
return time
|
model/trainer.py
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import os
|
4 |
+
import gc
|
5 |
+
from tqdm import tqdm
|
6 |
+
import wandb
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from torch.optim import AdamW
|
10 |
+
from torch.utils.data import DataLoader, Dataset, SequentialSampler
|
11 |
+
from torch.optim.lr_scheduler import LinearLR, SequentialLR
|
12 |
+
|
13 |
+
from einops import rearrange
|
14 |
+
|
15 |
+
from accelerate import Accelerator
|
16 |
+
from accelerate.utils import DistributedDataParallelKwargs
|
17 |
+
|
18 |
+
from ema_pytorch import EMA
|
19 |
+
|
20 |
+
from model import CFM
|
21 |
+
from model.utils import exists, default
|
22 |
+
from model.dataset import DynamicBatchSampler, collate_fn
|
23 |
+
|
24 |
+
|
25 |
+
# trainer
|
26 |
+
|
27 |
+
class Trainer:
|
28 |
+
def __init__(
|
29 |
+
self,
|
30 |
+
model: CFM,
|
31 |
+
epochs,
|
32 |
+
learning_rate,
|
33 |
+
num_warmup_updates = 20000,
|
34 |
+
save_per_updates = 1000,
|
35 |
+
checkpoint_path = None,
|
36 |
+
batch_size = 32,
|
37 |
+
batch_size_type: str = "sample",
|
38 |
+
max_samples = 32,
|
39 |
+
grad_accumulation_steps = 1,
|
40 |
+
max_grad_norm = 1.0,
|
41 |
+
noise_scheduler: str | None = None,
|
42 |
+
duration_predictor: torch.nn.Module | None = None,
|
43 |
+
wandb_project = "test_e2-tts",
|
44 |
+
wandb_run_name = "test_run",
|
45 |
+
wandb_resume_id: str = None,
|
46 |
+
last_per_steps = None,
|
47 |
+
accelerate_kwargs: dict = dict(),
|
48 |
+
ema_kwargs: dict = dict()
|
49 |
+
):
|
50 |
+
|
51 |
+
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
|
52 |
+
|
53 |
+
self.accelerator = Accelerator(
|
54 |
+
log_with = "wandb",
|
55 |
+
kwargs_handlers = [ddp_kwargs],
|
56 |
+
gradient_accumulation_steps = grad_accumulation_steps,
|
57 |
+
**accelerate_kwargs
|
58 |
+
)
|
59 |
+
|
60 |
+
if exists(wandb_resume_id):
|
61 |
+
init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name, 'id': wandb_resume_id}}
|
62 |
+
else:
|
63 |
+
init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name}}
|
64 |
+
self.accelerator.init_trackers(
|
65 |
+
project_name = wandb_project,
|
66 |
+
init_kwargs=init_kwargs,
|
67 |
+
config={"epochs": epochs,
|
68 |
+
"learning_rate": learning_rate,
|
69 |
+
"num_warmup_updates": num_warmup_updates,
|
70 |
+
"batch_size": batch_size,
|
71 |
+
"batch_size_type": batch_size_type,
|
72 |
+
"max_samples": max_samples,
|
73 |
+
"grad_accumulation_steps": grad_accumulation_steps,
|
74 |
+
"max_grad_norm": max_grad_norm,
|
75 |
+
"gpus": self.accelerator.num_processes,
|
76 |
+
"noise_scheduler": noise_scheduler}
|
77 |
+
)
|
78 |
+
|
79 |
+
self.model = model
|
80 |
+
|
81 |
+
if self.is_main:
|
82 |
+
self.ema_model = EMA(
|
83 |
+
model,
|
84 |
+
include_online_model = False,
|
85 |
+
**ema_kwargs
|
86 |
+
)
|
87 |
+
|
88 |
+
self.ema_model.to(self.accelerator.device)
|
89 |
+
|
90 |
+
self.epochs = epochs
|
91 |
+
self.num_warmup_updates = num_warmup_updates
|
92 |
+
self.save_per_updates = save_per_updates
|
93 |
+
self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
|
94 |
+
self.checkpoint_path = default(checkpoint_path, 'ckpts/test_e2-tts')
|
95 |
+
|
96 |
+
self.batch_size = batch_size
|
97 |
+
self.batch_size_type = batch_size_type
|
98 |
+
self.max_samples = max_samples
|
99 |
+
self.grad_accumulation_steps = grad_accumulation_steps
|
100 |
+
self.max_grad_norm = max_grad_norm
|
101 |
+
|
102 |
+
self.noise_scheduler = noise_scheduler
|
103 |
+
|
104 |
+
self.duration_predictor = duration_predictor
|
105 |
+
|
106 |
+
self.optimizer = AdamW(model.parameters(), lr=learning_rate)
|
107 |
+
self.model, self.optimizer = self.accelerator.prepare(
|
108 |
+
self.model, self.optimizer
|
109 |
+
)
|
110 |
+
|
111 |
+
@property
|
112 |
+
def is_main(self):
|
113 |
+
return self.accelerator.is_main_process
|
114 |
+
|
115 |
+
def save_checkpoint(self, step, last=False):
|
116 |
+
self.accelerator.wait_for_everyone()
|
117 |
+
if self.is_main:
|
118 |
+
checkpoint = dict(
|
119 |
+
model_state_dict = self.accelerator.unwrap_model(self.model).state_dict(),
|
120 |
+
optimizer_state_dict = self.accelerator.unwrap_model(self.optimizer).state_dict(),
|
121 |
+
ema_model_state_dict = self.ema_model.state_dict(),
|
122 |
+
scheduler_state_dict = self.scheduler.state_dict(),
|
123 |
+
step = step
|
124 |
+
)
|
125 |
+
if not os.path.exists(self.checkpoint_path):
|
126 |
+
os.makedirs(self.checkpoint_path)
|
127 |
+
if last == True:
|
128 |
+
self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
|
129 |
+
print(f"Saved last checkpoint at step {step}")
|
130 |
+
else:
|
131 |
+
self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
|
132 |
+
|
133 |
+
def load_checkpoint(self):
|
134 |
+
if not exists(self.checkpoint_path) or not os.path.exists(self.checkpoint_path) or not os.listdir(self.checkpoint_path):
|
135 |
+
return 0
|
136 |
+
|
137 |
+
self.accelerator.wait_for_everyone()
|
138 |
+
if "model_last.pt" in os.listdir(self.checkpoint_path):
|
139 |
+
latest_checkpoint = "model_last.pt"
|
140 |
+
else:
|
141 |
+
latest_checkpoint = sorted(os.listdir(self.checkpoint_path), key=lambda x: int(''.join(filter(str.isdigit, x))))[-1]
|
142 |
+
# checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
|
143 |
+
checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location="cpu")
|
144 |
+
self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
|
145 |
+
self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint['optimizer_state_dict'])
|
146 |
+
|
147 |
+
if self.is_main:
|
148 |
+
self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
|
149 |
+
|
150 |
+
if self.scheduler:
|
151 |
+
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
152 |
+
|
153 |
+
step = checkpoint['step']
|
154 |
+
del checkpoint; gc.collect()
|
155 |
+
return step
|
156 |
+
|
157 |
+
def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
|
158 |
+
|
159 |
+
if exists(resumable_with_seed):
|
160 |
+
generator = torch.Generator()
|
161 |
+
generator.manual_seed(resumable_with_seed)
|
162 |
+
else:
|
163 |
+
generator = None
|
164 |
+
|
165 |
+
if self.batch_size_type == "sample":
|
166 |
+
train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True,
|
167 |
+
batch_size=self.batch_size, shuffle=True, generator=generator)
|
168 |
+
elif self.batch_size_type == "frame":
|
169 |
+
self.accelerator.even_batches = False
|
170 |
+
sampler = SequentialSampler(train_dataset)
|
171 |
+
batch_sampler = DynamicBatchSampler(sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False)
|
172 |
+
train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True,
|
173 |
+
batch_sampler=batch_sampler)
|
174 |
+
else:
|
175 |
+
raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but recieved {self.batch_size_type}")
|
176 |
+
|
177 |
+
# accelerator.prepare() dispatches batches to devices;
|
178 |
+
# which means the length of dataloader calculated before, should consider the number of devices
|
179 |
+
warmup_steps = self.num_warmup_updates * self.accelerator.num_processes # consider a fixed warmup steps while using accelerate multi-gpu ddp
|
180 |
+
# otherwise by default with split_batches=False, warmup steps change with num_processes
|
181 |
+
total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
|
182 |
+
decay_steps = total_steps - warmup_steps
|
183 |
+
warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
|
184 |
+
decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
|
185 |
+
self.scheduler = SequentialLR(self.optimizer,
|
186 |
+
schedulers=[warmup_scheduler, decay_scheduler],
|
187 |
+
milestones=[warmup_steps])
|
188 |
+
train_dataloader, self.scheduler = self.accelerator.prepare(train_dataloader, self.scheduler) # actual steps = 1 gpu steps / gpus
|
189 |
+
start_step = self.load_checkpoint()
|
190 |
+
global_step = start_step
|
191 |
+
|
192 |
+
if exists(resumable_with_seed):
|
193 |
+
orig_epoch_step = len(train_dataloader)
|
194 |
+
skipped_epoch = int(start_step // orig_epoch_step)
|
195 |
+
skipped_batch = start_step % orig_epoch_step
|
196 |
+
skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
|
197 |
+
else:
|
198 |
+
skipped_epoch = 0
|
199 |
+
|
200 |
+
for epoch in range(skipped_epoch, self.epochs):
|
201 |
+
self.model.train()
|
202 |
+
if exists(resumable_with_seed) and epoch == skipped_epoch:
|
203 |
+
progress_bar = tqdm(skipped_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process,
|
204 |
+
initial=skipped_batch, total=orig_epoch_step)
|
205 |
+
else:
|
206 |
+
progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process)
|
207 |
+
|
208 |
+
for batch in progress_bar:
|
209 |
+
with self.accelerator.accumulate(self.model):
|
210 |
+
text_inputs = batch['text']
|
211 |
+
mel_spec = rearrange(batch['mel'], 'b d n -> b n d')
|
212 |
+
mel_lengths = batch["mel_lengths"]
|
213 |
+
|
214 |
+
# TODO. add duration predictor training
|
215 |
+
if self.duration_predictor is not None and self.accelerator.is_local_main_process:
|
216 |
+
dur_loss = self.duration_predictor(mel_spec, lens=batch.get('durations'))
|
217 |
+
self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)
|
218 |
+
|
219 |
+
loss, cond, pred = self.model(mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler)
|
220 |
+
self.accelerator.backward(loss)
|
221 |
+
|
222 |
+
if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
|
223 |
+
self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
|
224 |
+
|
225 |
+
self.optimizer.step()
|
226 |
+
self.scheduler.step()
|
227 |
+
self.optimizer.zero_grad()
|
228 |
+
|
229 |
+
if self.is_main:
|
230 |
+
self.ema_model.update()
|
231 |
+
|
232 |
+
global_step += 1
|
233 |
+
|
234 |
+
if self.accelerator.is_local_main_process:
|
235 |
+
self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
|
236 |
+
|
237 |
+
progress_bar.set_postfix(step=str(global_step), loss=loss.item())
|
238 |
+
|
239 |
+
if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
|
240 |
+
self.save_checkpoint(global_step)
|
241 |
+
|
242 |
+
if global_step % self.last_per_steps == 0:
|
243 |
+
self.save_checkpoint(global_step, last=True)
|
244 |
+
|
245 |
+
self.accelerator.end_training()
|
model/utils.py
ADDED
@@ -0,0 +1,545 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
import math
|
6 |
+
import random
|
7 |
+
import string
|
8 |
+
from tqdm import tqdm
|
9 |
+
from collections import defaultdict
|
10 |
+
|
11 |
+
import matplotlib
|
12 |
+
matplotlib.use("Agg")
|
13 |
+
import matplotlib.pylab as plt
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from torch.nn.utils.rnn import pad_sequence
|
18 |
+
import torchaudio
|
19 |
+
|
20 |
+
import einx
|
21 |
+
from einops import rearrange, reduce
|
22 |
+
|
23 |
+
import jieba
|
24 |
+
from pypinyin import lazy_pinyin, Style
|
25 |
+
import zhconv
|
26 |
+
from zhon.hanzi import punctuation
|
27 |
+
from jiwer import compute_measures
|
28 |
+
|
29 |
+
from funasr import AutoModel
|
30 |
+
from faster_whisper import WhisperModel
|
31 |
+
|
32 |
+
from model.ecapa_tdnn import ECAPA_TDNN_SMALL
|
33 |
+
from model.modules import MelSpec
|
34 |
+
|
35 |
+
|
36 |
+
# seed everything
|
37 |
+
|
38 |
+
def seed_everything(seed = 0):
|
39 |
+
random.seed(seed)
|
40 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
41 |
+
torch.manual_seed(seed)
|
42 |
+
torch.cuda.manual_seed(seed)
|
43 |
+
torch.cuda.manual_seed_all(seed)
|
44 |
+
torch.backends.cudnn.deterministic = True
|
45 |
+
torch.backends.cudnn.benchmark = False
|
46 |
+
|
47 |
+
# helpers
|
48 |
+
|
49 |
+
def exists(v):
|
50 |
+
return v is not None
|
51 |
+
|
52 |
+
def default(v, d):
|
53 |
+
return v if exists(v) else d
|
54 |
+
|
55 |
+
# tensor helpers
|
56 |
+
|
57 |
+
def lens_to_mask(
|
58 |
+
t: int['b'],
|
59 |
+
length: int | None = None
|
60 |
+
) -> bool['b n']:
|
61 |
+
|
62 |
+
if not exists(length):
|
63 |
+
length = t.amax()
|
64 |
+
|
65 |
+
seq = torch.arange(length, device = t.device)
|
66 |
+
return einx.less('n, b -> b n', seq, t)
|
67 |
+
|
68 |
+
def mask_from_start_end_indices(
|
69 |
+
seq_len: int['b'],
|
70 |
+
start: int['b'],
|
71 |
+
end: int['b']
|
72 |
+
):
|
73 |
+
max_seq_len = seq_len.max().item()
|
74 |
+
seq = torch.arange(max_seq_len, device = start.device).long()
|
75 |
+
return einx.greater_equal('n, b -> b n', seq, start) & einx.less('n, b -> b n', seq, end)
|
76 |
+
|
77 |
+
def mask_from_frac_lengths(
|
78 |
+
seq_len: int['b'],
|
79 |
+
frac_lengths: float['b']
|
80 |
+
):
|
81 |
+
lengths = (frac_lengths * seq_len).long()
|
82 |
+
max_start = seq_len - lengths
|
83 |
+
|
84 |
+
rand = torch.rand_like(frac_lengths)
|
85 |
+
start = (max_start * rand).long().clamp(min = 0)
|
86 |
+
end = start + lengths
|
87 |
+
|
88 |
+
return mask_from_start_end_indices(seq_len, start, end)
|
89 |
+
|
90 |
+
def maybe_masked_mean(
|
91 |
+
t: float['b n d'],
|
92 |
+
mask: bool['b n'] = None
|
93 |
+
) -> float['b d']:
|
94 |
+
|
95 |
+
if not exists(mask):
|
96 |
+
return t.mean(dim = 1)
|
97 |
+
|
98 |
+
t = einx.where('b n, b n d, -> b n d', mask, t, 0.)
|
99 |
+
num = reduce(t, 'b n d -> b d', 'sum')
|
100 |
+
den = reduce(mask.float(), 'b n -> b', 'sum')
|
101 |
+
|
102 |
+
return einx.divide('b d, b -> b d', num, den.clamp(min = 1.))
|
103 |
+
|
104 |
+
|
105 |
+
# simple utf-8 tokenizer, since paper went character based
|
106 |
+
def list_str_to_tensor(
|
107 |
+
text: list[str],
|
108 |
+
padding_value = -1
|
109 |
+
) -> int['b nt']:
|
110 |
+
list_tensors = [torch.tensor([*bytes(t, 'UTF-8')]) for t in text] # ByT5 style
|
111 |
+
text = pad_sequence(list_tensors, padding_value = padding_value, batch_first = True)
|
112 |
+
return text
|
113 |
+
|
114 |
+
# char tokenizer, based on custom dataset's extracted .txt file
|
115 |
+
def list_str_to_idx(
|
116 |
+
text: list[str] | list[list[str]],
|
117 |
+
vocab_char_map: dict[str, int], # {char: idx}
|
118 |
+
padding_value = -1
|
119 |
+
) -> int['b nt']:
|
120 |
+
list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
|
121 |
+
text = pad_sequence(list_idx_tensors, padding_value = padding_value, batch_first = True)
|
122 |
+
return text
|
123 |
+
|
124 |
+
|
125 |
+
# Get tokenizer
|
126 |
+
|
127 |
+
def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
|
128 |
+
'''
|
129 |
+
tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
|
130 |
+
- "char" for char-wise tokenizer, need .txt vocab_file
|
131 |
+
- "byte" for utf-8 tokenizer
|
132 |
+
vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
|
133 |
+
- if use "char", derived from unfiltered character & symbol counts of custom dataset
|
134 |
+
- if use "byte", set to 256 (unicode byte range)
|
135 |
+
'''
|
136 |
+
if tokenizer in ["pinyin", "char"]:
|
137 |
+
with open (f"data/{dataset_name}_{tokenizer}/vocab.txt", "r") as f:
|
138 |
+
vocab_char_map = {}
|
139 |
+
for i, char in enumerate(f):
|
140 |
+
vocab_char_map[char[:-1]] = i
|
141 |
+
vocab_size = len(vocab_char_map)
|
142 |
+
assert vocab_char_map[" "] == 0, "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char"
|
143 |
+
|
144 |
+
elif tokenizer == "byte":
|
145 |
+
vocab_char_map = None
|
146 |
+
vocab_size = 256
|
147 |
+
|
148 |
+
return vocab_char_map, vocab_size
|
149 |
+
|
150 |
+
|
151 |
+
# convert char to pinyin
|
152 |
+
|
153 |
+
def convert_char_to_pinyin(text_list, polyphone = True):
|
154 |
+
final_text_list = []
|
155 |
+
god_knows_why_en_testset_contains_zh_quote = str.maketrans({'“': '"', '”': '"', '‘': "'", '’': "'"}) # in case librispeech (orig no-pc) test-clean
|
156 |
+
for text in text_list:
|
157 |
+
char_list = []
|
158 |
+
text = text.translate(god_knows_why_en_testset_contains_zh_quote)
|
159 |
+
for seg in jieba.cut(text):
|
160 |
+
seg_byte_len = len(bytes(seg, 'UTF-8'))
|
161 |
+
if seg_byte_len == len(seg): # if pure alphabets and symbols
|
162 |
+
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
|
163 |
+
char_list.append(" ")
|
164 |
+
char_list.extend(seg)
|
165 |
+
elif polyphone and seg_byte_len == 3 * len(seg): # if pure chinese characters
|
166 |
+
seg = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
|
167 |
+
for c in seg:
|
168 |
+
if c not in "。,、;:?!《》【】—…":
|
169 |
+
char_list.append(" ")
|
170 |
+
char_list.append(c)
|
171 |
+
else: # if mixed chinese characters, alphabets and symbols
|
172 |
+
for c in seg:
|
173 |
+
if ord(c) < 256:
|
174 |
+
char_list.extend(c)
|
175 |
+
else:
|
176 |
+
if c not in "。,、;:?!《》【】—…":
|
177 |
+
char_list.append(" ")
|
178 |
+
char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
|
179 |
+
else: # if is zh punc
|
180 |
+
char_list.append(c)
|
181 |
+
final_text_list.append(char_list)
|
182 |
+
|
183 |
+
return final_text_list
|
184 |
+
|
185 |
+
|
186 |
+
# save spectrogram
|
187 |
+
def save_spectrogram(spectrogram, path):
|
188 |
+
plt.figure(figsize=(12, 4))
|
189 |
+
plt.imshow(spectrogram, origin='lower', aspect='auto')
|
190 |
+
plt.colorbar()
|
191 |
+
plt.savefig(path)
|
192 |
+
plt.close()
|
193 |
+
|
194 |
+
|
195 |
+
# seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
|
196 |
+
def get_seedtts_testset_metainfo(metalst):
|
197 |
+
f = open(metalst); lines = f.readlines(); f.close()
|
198 |
+
metainfo = []
|
199 |
+
for line in lines:
|
200 |
+
if len(line.strip().split('|')) == 5:
|
201 |
+
utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
|
202 |
+
elif len(line.strip().split('|')) == 4:
|
203 |
+
utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
|
204 |
+
gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
|
205 |
+
if not os.path.isabs(prompt_wav):
|
206 |
+
prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
|
207 |
+
metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav))
|
208 |
+
return metainfo
|
209 |
+
|
210 |
+
|
211 |
+
# librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav
|
212 |
+
def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
|
213 |
+
f = open(metalst); lines = f.readlines(); f.close()
|
214 |
+
metainfo = []
|
215 |
+
for line in lines:
|
216 |
+
ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
|
217 |
+
|
218 |
+
# ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
|
219 |
+
ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
|
220 |
+
ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
|
221 |
+
|
222 |
+
# gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
|
223 |
+
gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
|
224 |
+
gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
|
225 |
+
|
226 |
+
metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))
|
227 |
+
|
228 |
+
return metainfo
|
229 |
+
|
230 |
+
|
231 |
+
# padded to max length mel batch
|
232 |
+
def padded_mel_batch(ref_mels):
|
233 |
+
max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
|
234 |
+
padded_ref_mels = []
|
235 |
+
for mel in ref_mels:
|
236 |
+
padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value = 0)
|
237 |
+
padded_ref_mels.append(padded_ref_mel)
|
238 |
+
padded_ref_mels = torch.stack(padded_ref_mels)
|
239 |
+
padded_ref_mels = rearrange(padded_ref_mels, 'b d n -> b n d')
|
240 |
+
return padded_ref_mels
|
241 |
+
|
242 |
+
|
243 |
+
# get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav
|
244 |
+
|
245 |
+
def get_inference_prompt(
|
246 |
+
metainfo,
|
247 |
+
speed = 1., tokenizer = "pinyin", polyphone = True,
|
248 |
+
target_sample_rate = 24000, n_mel_channels = 100, hop_length = 256, target_rms = 0.1,
|
249 |
+
use_truth_duration = False,
|
250 |
+
infer_batch_size = 1, num_buckets = 200, min_secs = 3, max_secs = 40,
|
251 |
+
):
|
252 |
+
prompts_all = []
|
253 |
+
|
254 |
+
min_tokens = min_secs * target_sample_rate // hop_length
|
255 |
+
max_tokens = max_secs * target_sample_rate // hop_length
|
256 |
+
|
257 |
+
batch_accum = [0] * num_buckets
|
258 |
+
utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = \
|
259 |
+
([[] for _ in range(num_buckets)] for _ in range(6))
|
260 |
+
|
261 |
+
mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length)
|
262 |
+
|
263 |
+
for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
|
264 |
+
|
265 |
+
# Audio
|
266 |
+
ref_audio, ref_sr = torchaudio.load(prompt_wav)
|
267 |
+
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
|
268 |
+
if ref_rms < target_rms:
|
269 |
+
ref_audio = ref_audio * target_rms / ref_rms
|
270 |
+
assert ref_audio.shape[-1] > 5000, f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
|
271 |
+
if ref_sr != target_sample_rate:
|
272 |
+
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
|
273 |
+
ref_audio = resampler(ref_audio)
|
274 |
+
|
275 |
+
# Text
|
276 |
+
text = [prompt_text + gt_text]
|
277 |
+
if tokenizer == "pinyin":
|
278 |
+
text_list = convert_char_to_pinyin(text, polyphone = polyphone)
|
279 |
+
else:
|
280 |
+
text_list = text
|
281 |
+
|
282 |
+
# Duration, mel frame length
|
283 |
+
ref_mel_len = ref_audio.shape[-1] // hop_length
|
284 |
+
if use_truth_duration:
|
285 |
+
gt_audio, gt_sr = torchaudio.load(gt_wav)
|
286 |
+
if gt_sr != target_sample_rate:
|
287 |
+
resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate)
|
288 |
+
gt_audio = resampler(gt_audio)
|
289 |
+
total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed)
|
290 |
+
|
291 |
+
# # test vocoder resynthesis
|
292 |
+
# ref_audio = gt_audio
|
293 |
+
else:
|
294 |
+
zh_pause_punc = r"。,、;:?!"
|
295 |
+
ref_text_len = len(prompt_text) + len(re.findall(zh_pause_punc, prompt_text))
|
296 |
+
gen_text_len = len(gt_text) + len(re.findall(zh_pause_punc, gt_text))
|
297 |
+
total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
|
298 |
+
|
299 |
+
# to mel spectrogram
|
300 |
+
ref_mel = mel_spectrogram(ref_audio)
|
301 |
+
ref_mel = rearrange(ref_mel, '1 d n -> d n')
|
302 |
+
|
303 |
+
# deal with batch
|
304 |
+
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
|
305 |
+
assert min_tokens <= total_mel_len <= max_tokens, \
|
306 |
+
f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
|
307 |
+
bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
|
308 |
+
|
309 |
+
utts[bucket_i].append(utt)
|
310 |
+
ref_rms_list[bucket_i].append(ref_rms)
|
311 |
+
ref_mels[bucket_i].append(ref_mel)
|
312 |
+
ref_mel_lens[bucket_i].append(ref_mel_len)
|
313 |
+
total_mel_lens[bucket_i].append(total_mel_len)
|
314 |
+
final_text_list[bucket_i].extend(text_list)
|
315 |
+
|
316 |
+
batch_accum[bucket_i] += total_mel_len
|
317 |
+
|
318 |
+
if batch_accum[bucket_i] >= infer_batch_size:
|
319 |
+
# print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
|
320 |
+
prompts_all.append((
|
321 |
+
utts[bucket_i],
|
322 |
+
ref_rms_list[bucket_i],
|
323 |
+
padded_mel_batch(ref_mels[bucket_i]),
|
324 |
+
ref_mel_lens[bucket_i],
|
325 |
+
total_mel_lens[bucket_i],
|
326 |
+
final_text_list[bucket_i]
|
327 |
+
))
|
328 |
+
batch_accum[bucket_i] = 0
|
329 |
+
utts[bucket_i], ref_rms_list[bucket_i], ref_mels[bucket_i], ref_mel_lens[bucket_i], total_mel_lens[bucket_i], final_text_list[bucket_i] = [], [], [], [], [], []
|
330 |
+
|
331 |
+
# add residual
|
332 |
+
for bucket_i, bucket_frames in enumerate(batch_accum):
|
333 |
+
if bucket_frames > 0:
|
334 |
+
prompts_all.append((
|
335 |
+
utts[bucket_i],
|
336 |
+
ref_rms_list[bucket_i],
|
337 |
+
padded_mel_batch(ref_mels[bucket_i]),
|
338 |
+
ref_mel_lens[bucket_i],
|
339 |
+
total_mel_lens[bucket_i],
|
340 |
+
final_text_list[bucket_i]
|
341 |
+
))
|
342 |
+
# not only leave easy work for last workers
|
343 |
+
random.seed(666)
|
344 |
+
random.shuffle(prompts_all)
|
345 |
+
|
346 |
+
return prompts_all
|
347 |
+
|
348 |
+
|
349 |
+
# get wav_res_ref_text of seed-tts test metalst
|
350 |
+
# https://github.com/BytedanceSpeech/seed-tts-eval
|
351 |
+
|
352 |
+
def get_seed_tts_test(metalst, gen_wav_dir, gpus):
|
353 |
+
f = open(metalst)
|
354 |
+
lines = f.readlines()
|
355 |
+
f.close()
|
356 |
+
|
357 |
+
test_set_ = []
|
358 |
+
for line in tqdm(lines):
|
359 |
+
if len(line.strip().split('|')) == 5:
|
360 |
+
utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
|
361 |
+
elif len(line.strip().split('|')) == 4:
|
362 |
+
utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
|
363 |
+
|
364 |
+
if not os.path.exists(os.path.join(gen_wav_dir, utt + '.wav')):
|
365 |
+
continue
|
366 |
+
gen_wav = os.path.join(gen_wav_dir, utt + '.wav')
|
367 |
+
if not os.path.isabs(prompt_wav):
|
368 |
+
prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
|
369 |
+
|
370 |
+
test_set_.append((gen_wav, prompt_wav, gt_text))
|
371 |
+
|
372 |
+
num_jobs = len(gpus)
|
373 |
+
if num_jobs == 1:
|
374 |
+
return [(gpus[0], test_set_)]
|
375 |
+
|
376 |
+
wav_per_job = len(test_set_) // num_jobs + 1
|
377 |
+
test_set = []
|
378 |
+
for i in range(num_jobs):
|
379 |
+
test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
|
380 |
+
|
381 |
+
return test_set
|
382 |
+
|
383 |
+
|
384 |
+
# get librispeech test-clean cross sentence test
|
385 |
+
|
386 |
+
def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = False):
|
387 |
+
f = open(metalst)
|
388 |
+
lines = f.readlines()
|
389 |
+
f.close()
|
390 |
+
|
391 |
+
test_set_ = []
|
392 |
+
for line in tqdm(lines):
|
393 |
+
ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
|
394 |
+
|
395 |
+
if eval_ground_truth:
|
396 |
+
gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
|
397 |
+
gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
|
398 |
+
else:
|
399 |
+
if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + '.wav')):
|
400 |
+
raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
|
401 |
+
gen_wav = os.path.join(gen_wav_dir, gen_utt + '.wav')
|
402 |
+
|
403 |
+
ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
|
404 |
+
ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
|
405 |
+
|
406 |
+
test_set_.append((gen_wav, ref_wav, gen_txt))
|
407 |
+
|
408 |
+
num_jobs = len(gpus)
|
409 |
+
if num_jobs == 1:
|
410 |
+
return [(gpus[0], test_set_)]
|
411 |
+
|
412 |
+
wav_per_job = len(test_set_) // num_jobs + 1
|
413 |
+
test_set = []
|
414 |
+
for i in range(num_jobs):
|
415 |
+
test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
|
416 |
+
|
417 |
+
return test_set
|
418 |
+
|
419 |
+
|
420 |
+
# load asr model
|
421 |
+
|
422 |
+
def load_asr_model(lang, ckpt_dir = ""):
|
423 |
+
if lang == "zh":
|
424 |
+
model = AutoModel(
|
425 |
+
model = os.path.join(ckpt_dir, "paraformer-zh"),
|
426 |
+
# vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
|
427 |
+
# punc_model = os.path.join(ckpt_dir, "ct-punc"),
|
428 |
+
# spk_model = os.path.join(ckpt_dir, "cam++"),
|
429 |
+
disable_update=True,
|
430 |
+
) # following seed-tts setting
|
431 |
+
elif lang == "en":
|
432 |
+
model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
|
433 |
+
model = WhisperModel(model_size, device="cuda", compute_type="float16")
|
434 |
+
return model
|
435 |
+
|
436 |
+
|
437 |
+
# WER Evaluation, the way Seed-TTS does
|
438 |
+
|
439 |
+
def run_asr_wer(args):
|
440 |
+
rank, lang, test_set, ckpt_dir = args
|
441 |
+
|
442 |
+
if lang == "zh":
|
443 |
+
torch.cuda.set_device(rank)
|
444 |
+
elif lang == "en":
|
445 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
|
446 |
+
else:
|
447 |
+
raise NotImplementedError("lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now.")
|
448 |
+
|
449 |
+
asr_model = load_asr_model(lang, ckpt_dir = ckpt_dir)
|
450 |
+
|
451 |
+
punctuation_all = punctuation + string.punctuation
|
452 |
+
wers = []
|
453 |
+
|
454 |
+
for gen_wav, prompt_wav, truth in tqdm(test_set):
|
455 |
+
if lang == "zh":
|
456 |
+
res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
|
457 |
+
hypo = res[0]["text"]
|
458 |
+
hypo = zhconv.convert(hypo, 'zh-cn')
|
459 |
+
elif lang == "en":
|
460 |
+
segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en")
|
461 |
+
hypo = ''
|
462 |
+
for segment in segments:
|
463 |
+
hypo = hypo + ' ' + segment.text
|
464 |
+
|
465 |
+
# raw_truth = truth
|
466 |
+
# raw_hypo = hypo
|
467 |
+
|
468 |
+
for x in punctuation_all:
|
469 |
+
truth = truth.replace(x, '')
|
470 |
+
hypo = hypo.replace(x, '')
|
471 |
+
|
472 |
+
truth = truth.replace(' ', ' ')
|
473 |
+
hypo = hypo.replace(' ', ' ')
|
474 |
+
|
475 |
+
if lang == "zh":
|
476 |
+
truth = " ".join([x for x in truth])
|
477 |
+
hypo = " ".join([x for x in hypo])
|
478 |
+
elif lang == "en":
|
479 |
+
truth = truth.lower()
|
480 |
+
hypo = hypo.lower()
|
481 |
+
|
482 |
+
measures = compute_measures(truth, hypo)
|
483 |
+
wer = measures["wer"]
|
484 |
+
|
485 |
+
# ref_list = truth.split(" ")
|
486 |
+
# subs = measures["substitutions"] / len(ref_list)
|
487 |
+
# dele = measures["deletions"] / len(ref_list)
|
488 |
+
# inse = measures["insertions"] / len(ref_list)
|
489 |
+
|
490 |
+
wers.append(wer)
|
491 |
+
|
492 |
+
return wers
|
493 |
+
|
494 |
+
|
495 |
+
# SIM Evaluation
|
496 |
+
|
497 |
+
def run_sim(args):
|
498 |
+
rank, test_set, ckpt_dir = args
|
499 |
+
device = f"cuda:{rank}"
|
500 |
+
|
501 |
+
model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large', config_path=None)
|
502 |
+
state_dict = torch.load(ckpt_dir, map_location=lambda storage, loc: storage)
|
503 |
+
model.load_state_dict(state_dict['model'], strict=False)
|
504 |
+
|
505 |
+
use_gpu=True if torch.cuda.is_available() else False
|
506 |
+
if use_gpu:
|
507 |
+
model = model.cuda(device)
|
508 |
+
model.eval()
|
509 |
+
|
510 |
+
sim_list = []
|
511 |
+
for wav1, wav2, truth in tqdm(test_set):
|
512 |
+
|
513 |
+
wav1, sr1 = torchaudio.load(wav1)
|
514 |
+
wav2, sr2 = torchaudio.load(wav2)
|
515 |
+
|
516 |
+
resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
|
517 |
+
resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
|
518 |
+
wav1 = resample1(wav1)
|
519 |
+
wav2 = resample2(wav2)
|
520 |
+
|
521 |
+
if use_gpu:
|
522 |
+
wav1 = wav1.cuda(device)
|
523 |
+
wav2 = wav2.cuda(device)
|
524 |
+
with torch.no_grad():
|
525 |
+
emb1 = model(wav1)
|
526 |
+
emb2 = model(wav2)
|
527 |
+
|
528 |
+
sim = F.cosine_similarity(emb1, emb2)[0].item()
|
529 |
+
# print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
|
530 |
+
sim_list.append(sim)
|
531 |
+
|
532 |
+
return sim_list
|
533 |
+
|
534 |
+
|
535 |
+
# filter func for dirty data with many repetitions
|
536 |
+
|
537 |
+
def repetition_found(text, length = 2, tolerance = 10):
|
538 |
+
pattern_count = defaultdict(int)
|
539 |
+
for i in range(len(text) - length + 1):
|
540 |
+
pattern = text[i:i + length]
|
541 |
+
pattern_count[pattern] += 1
|
542 |
+
for pattern, count in pattern_count.items():
|
543 |
+
if count > tolerance:
|
544 |
+
return True
|
545 |
+
return False
|
packages.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
ffmpeg
|
requirements.txt
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate>=0.33.0
|
2 |
+
datasets
|
3 |
+
einops>=0.8.0
|
4 |
+
einx>=0.3.0
|
5 |
+
ema_pytorch>=0.5.2
|
6 |
+
faster_whisper
|
7 |
+
funasr
|
8 |
+
jieba
|
9 |
+
jiwer
|
10 |
+
librosa
|
11 |
+
matplotlib
|
12 |
+
pypinyin
|
13 |
+
torch>=2.0
|
14 |
+
torchaudio>=2.3.0
|
15 |
+
torchdiffeq
|
16 |
+
tqdm>=4.65.0
|
17 |
+
transformers
|
18 |
+
vocos
|
19 |
+
wandb
|
20 |
+
x_transformers>=1.31.14
|
21 |
+
zhconv
|
22 |
+
zhon
|
23 |
+
cached_path
|
24 |
+
pydub
|
scripts/count_max_epoch.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''ADAPTIVE BATCH SIZE'''
|
2 |
+
print('Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in')
|
3 |
+
print(' -> least padding, gather wavs with accumulated frames in a batch\n')
|
4 |
+
|
5 |
+
# data
|
6 |
+
total_hours = 95282
|
7 |
+
mel_hop_length = 256
|
8 |
+
mel_sampling_rate = 24000
|
9 |
+
|
10 |
+
# target
|
11 |
+
wanted_max_updates = 1000000
|
12 |
+
|
13 |
+
# train params
|
14 |
+
gpus = 8
|
15 |
+
frames_per_gpu = 38400 # 8 * 38400 = 307200
|
16 |
+
grad_accum = 1
|
17 |
+
|
18 |
+
# intermediate
|
19 |
+
mini_batch_frames = frames_per_gpu * grad_accum * gpus
|
20 |
+
mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
|
21 |
+
updates_per_epoch = total_hours / mini_batch_hours
|
22 |
+
steps_per_epoch = updates_per_epoch * grad_accum
|
23 |
+
|
24 |
+
# result
|
25 |
+
epochs = wanted_max_updates / updates_per_epoch
|
26 |
+
print(f"epochs should be set to: {epochs:.0f} ({epochs/grad_accum:.1f} x gd_acum {grad_accum})")
|
27 |
+
print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
|
28 |
+
print(f" or approx. 0/{steps_per_epoch:.0f} steps")
|
29 |
+
|
30 |
+
# others
|
31 |
+
print(f"total {total_hours:.0f} hours")
|
32 |
+
print(f"mini-batch of {mini_batch_frames:.0f} frames, {mini_batch_hours:.2f} hours per mini-batch")
|
scripts/count_params_gflops.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys, os
|
2 |
+
sys.path.append(os.getcwd())
|
3 |
+
|
4 |
+
from model import M2_TTS, UNetT, DiT, MMDiT
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import thop
|
8 |
+
|
9 |
+
|
10 |
+
''' ~155M '''
|
11 |
+
# transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4)
|
12 |
+
# transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4)
|
13 |
+
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2)
|
14 |
+
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4)
|
15 |
+
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True)
|
16 |
+
# transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2)
|
17 |
+
|
18 |
+
''' ~335M '''
|
19 |
+
# FLOPs: 622.1 G, Params: 333.2 M
|
20 |
+
# transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
|
21 |
+
# FLOPs: 363.4 G, Params: 335.8 M
|
22 |
+
transformer = DiT(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
|
23 |
+
|
24 |
+
|
25 |
+
model = M2_TTS(transformer=transformer)
|
26 |
+
target_sample_rate = 24000
|
27 |
+
n_mel_channels = 100
|
28 |
+
hop_length = 256
|
29 |
+
duration = 20
|
30 |
+
frame_length = int(duration * target_sample_rate / hop_length)
|
31 |
+
text_length = 150
|
32 |
+
|
33 |
+
flops, params = thop.profile(model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long)))
|
34 |
+
print(f"FLOPs: {flops / 1e9} G")
|
35 |
+
print(f"Params: {params / 1e6} M")
|
scripts/eval_librispeech_test_clean.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
|
2 |
+
|
3 |
+
import sys, os
|
4 |
+
sys.path.append(os.getcwd())
|
5 |
+
|
6 |
+
import multiprocessing as mp
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from model.utils import (
|
10 |
+
get_librispeech_test,
|
11 |
+
run_asr_wer,
|
12 |
+
run_sim,
|
13 |
+
)
|
14 |
+
|
15 |
+
|
16 |
+
eval_task = "wer" # sim | wer
|
17 |
+
lang = "en"
|
18 |
+
metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
|
19 |
+
librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
|
20 |
+
gen_wav_dir = "PATH_TO_GENERATED" # generated wavs
|
21 |
+
|
22 |
+
gpus = [0,1,2,3,4,5,6,7]
|
23 |
+
test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
|
24 |
+
|
25 |
+
## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
|
26 |
+
## leading to a low similarity for the ground truth in some cases.
|
27 |
+
# test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = True) # eval ground truth
|
28 |
+
|
29 |
+
local = False
|
30 |
+
if local: # use local custom checkpoint dir
|
31 |
+
asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
|
32 |
+
else:
|
33 |
+
asr_ckpt_dir = "" # auto download to cache dir
|
34 |
+
|
35 |
+
wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
|
36 |
+
|
37 |
+
|
38 |
+
# --------------------------- WER ---------------------------
|
39 |
+
|
40 |
+
if eval_task == "wer":
|
41 |
+
wers = []
|
42 |
+
|
43 |
+
with mp.Pool(processes=len(gpus)) as pool:
|
44 |
+
args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
|
45 |
+
results = pool.map(run_asr_wer, args)
|
46 |
+
for wers_ in results:
|
47 |
+
wers.extend(wers_)
|
48 |
+
|
49 |
+
wer = round(np.mean(wers)*100, 3)
|
50 |
+
print(f"\nTotal {len(wers)} samples")
|
51 |
+
print(f"WER : {wer}%")
|
52 |
+
|
53 |
+
|
54 |
+
# --------------------------- SIM ---------------------------
|
55 |
+
|
56 |
+
if eval_task == "sim":
|
57 |
+
sim_list = []
|
58 |
+
|
59 |
+
with mp.Pool(processes=len(gpus)) as pool:
|
60 |
+
args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
|
61 |
+
results = pool.map(run_sim, args)
|
62 |
+
for sim_ in results:
|
63 |
+
sim_list.extend(sim_)
|
64 |
+
|
65 |
+
sim = round(sum(sim_list)/len(sim_list), 3)
|
66 |
+
print(f"\nTotal {len(sim_list)} samples")
|
67 |
+
print(f"SIM : {sim}")
|
scripts/eval_seedtts_testset.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Evaluate with Seed-TTS testset
|
2 |
+
|
3 |
+
import sys, os
|
4 |
+
sys.path.append(os.getcwd())
|
5 |
+
|
6 |
+
import multiprocessing as mp
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from model.utils import (
|
10 |
+
get_seed_tts_test,
|
11 |
+
run_asr_wer,
|
12 |
+
run_sim,
|
13 |
+
)
|
14 |
+
|
15 |
+
|
16 |
+
eval_task = "wer" # sim | wer
|
17 |
+
lang = "zh" # zh | en
|
18 |
+
metalst = f"data/seedtts_testset/{lang}/meta.lst" # seed-tts testset
|
19 |
+
# gen_wav_dir = f"data/seedtts_testset/{lang}/wavs" # ground truth wavs
|
20 |
+
gen_wav_dir = f"PATH_TO_GENERATED" # generated wavs
|
21 |
+
|
22 |
+
|
23 |
+
# NOTE. paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different
|
24 |
+
# zh 1.254 seems a result of 4 workers wer_seed_tts
|
25 |
+
gpus = [0,1,2,3,4,5,6,7]
|
26 |
+
test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)
|
27 |
+
|
28 |
+
local = False
|
29 |
+
if local: # use local custom checkpoint dir
|
30 |
+
if lang == "zh":
|
31 |
+
asr_ckpt_dir = "../checkpoints/funasr" # paraformer-zh dir under funasr
|
32 |
+
elif lang == "en":
|
33 |
+
asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
|
34 |
+
else:
|
35 |
+
asr_ckpt_dir = "" # auto download to cache dir
|
36 |
+
|
37 |
+
wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
|
38 |
+
|
39 |
+
|
40 |
+
# --------------------------- WER ---------------------------
|
41 |
+
|
42 |
+
if eval_task == "wer":
|
43 |
+
wers = []
|
44 |
+
|
45 |
+
with mp.Pool(processes=len(gpus)) as pool:
|
46 |
+
args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
|
47 |
+
results = pool.map(run_asr_wer, args)
|
48 |
+
for wers_ in results:
|
49 |
+
wers.extend(wers_)
|
50 |
+
|
51 |
+
wer = round(np.mean(wers)*100, 3)
|
52 |
+
print(f"\nTotal {len(wers)} samples")
|
53 |
+
print(f"WER : {wer}%")
|
54 |
+
|
55 |
+
|
56 |
+
# --------------------------- SIM ---------------------------
|
57 |
+
|
58 |
+
if eval_task == "sim":
|
59 |
+
sim_list = []
|
60 |
+
|
61 |
+
with mp.Pool(processes=len(gpus)) as pool:
|
62 |
+
args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
|
63 |
+
results = pool.map(run_sim, args)
|
64 |
+
for sim_ in results:
|
65 |
+
sim_list.extend(sim_)
|
66 |
+
|
67 |
+
sim = round(sum(sim_list)/len(sim_list), 3)
|
68 |
+
print(f"\nTotal {len(sim_list)} samples")
|
69 |
+
print(f"SIM : {sim}")
|
scripts/prepare_emilia.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Emilia Dataset: https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07
|
2 |
+
# if use updated new version, i.e. WebDataset, feel free to modify / draft your own script
|
3 |
+
|
4 |
+
# generate audio text map for Emilia ZH & EN
|
5 |
+
# evaluate for vocab size
|
6 |
+
|
7 |
+
import sys, os
|
8 |
+
sys.path.append(os.getcwd())
|
9 |
+
|
10 |
+
from pathlib import Path
|
11 |
+
import json
|
12 |
+
from tqdm import tqdm
|
13 |
+
from concurrent.futures import ProcessPoolExecutor
|
14 |
+
|
15 |
+
from datasets import Dataset
|
16 |
+
from datasets.arrow_writer import ArrowWriter
|
17 |
+
|
18 |
+
from model.utils import (
|
19 |
+
repetition_found,
|
20 |
+
convert_char_to_pinyin,
|
21 |
+
)
|
22 |
+
|
23 |
+
|
24 |
+
out_zh = {"ZH_B00041_S06226", "ZH_B00042_S09204", "ZH_B00065_S09430", "ZH_B00065_S09431", "ZH_B00066_S09327", "ZH_B00066_S09328"}
|
25 |
+
zh_filters = ["い", "て"]
|
26 |
+
# seems synthesized audios, or heavily code-switched
|
27 |
+
out_en = {
|
28 |
+
"EN_B00013_S00913", "EN_B00042_S00120", "EN_B00055_S04111", "EN_B00061_S00693", "EN_B00061_S01494", "EN_B00061_S03375",
|
29 |
+
|
30 |
+
"EN_B00059_S00092", "EN_B00111_S04300", "EN_B00100_S03759", "EN_B00087_S03811", "EN_B00059_S00950", "EN_B00089_S00946", "EN_B00078_S05127", "EN_B00070_S04089", "EN_B00074_S09659", "EN_B00061_S06983", "EN_B00061_S07060", "EN_B00059_S08397", "EN_B00082_S06192", "EN_B00091_S01238", "EN_B00089_S07349", "EN_B00070_S04343", "EN_B00061_S02400", "EN_B00076_S01262", "EN_B00068_S06467", "EN_B00076_S02943", "EN_B00064_S05954", "EN_B00061_S05386", "EN_B00066_S06544", "EN_B00076_S06944", "EN_B00072_S08620", "EN_B00076_S07135", "EN_B00076_S09127", "EN_B00065_S00497", "EN_B00059_S06227", "EN_B00063_S02859", "EN_B00075_S01547", "EN_B00061_S08286", "EN_B00079_S02901", "EN_B00092_S03643", "EN_B00096_S08653", "EN_B00063_S04297", "EN_B00063_S04614", "EN_B00079_S04698", "EN_B00104_S01666", "EN_B00061_S09504", "EN_B00061_S09694", "EN_B00065_S05444", "EN_B00063_S06860", "EN_B00065_S05725", "EN_B00069_S07628", "EN_B00083_S03875", "EN_B00071_S07665", "EN_B00071_S07665", "EN_B00062_S04187", "EN_B00065_S09873", "EN_B00065_S09922", "EN_B00084_S02463", "EN_B00067_S05066", "EN_B00106_S08060", "EN_B00073_S06399", "EN_B00073_S09236", "EN_B00087_S00432", "EN_B00085_S05618", "EN_B00064_S01262", "EN_B00072_S01739", "EN_B00059_S03913", "EN_B00069_S04036", "EN_B00067_S05623", "EN_B00060_S05389", "EN_B00060_S07290", "EN_B00062_S08995",
|
31 |
+
}
|
32 |
+
en_filters = ["ا", "い", "て"]
|
33 |
+
|
34 |
+
|
35 |
+
def deal_with_audio_dir(audio_dir):
|
36 |
+
audio_jsonl = audio_dir.with_suffix(".jsonl")
|
37 |
+
sub_result, durations = [], []
|
38 |
+
vocab_set = set()
|
39 |
+
bad_case_zh = 0
|
40 |
+
bad_case_en = 0
|
41 |
+
with open(audio_jsonl, "r") as f:
|
42 |
+
lines = f.readlines()
|
43 |
+
for line in tqdm(lines, desc=f"{audio_jsonl.stem}"):
|
44 |
+
obj = json.loads(line)
|
45 |
+
text = obj["text"]
|
46 |
+
if obj['language'] == "zh":
|
47 |
+
if obj["wav"].split("/")[1] in out_zh or any(f in text for f in zh_filters) or repetition_found(text):
|
48 |
+
bad_case_zh += 1
|
49 |
+
continue
|
50 |
+
else:
|
51 |
+
text = text.translate(str.maketrans({',': ',', '!': '!', '?': '?'})) # not "。" cuz much code-switched
|
52 |
+
if obj['language'] == "en":
|
53 |
+
if obj["wav"].split("/")[1] in out_en or any(f in text for f in en_filters) or repetition_found(text, length=4):
|
54 |
+
bad_case_en += 1
|
55 |
+
continue
|
56 |
+
if tokenizer == "pinyin":
|
57 |
+
text = convert_char_to_pinyin([text], polyphone = polyphone)[0]
|
58 |
+
duration = obj["duration"]
|
59 |
+
sub_result.append({"audio_path": str(audio_dir.parent / obj["wav"]), "text": text, "duration": duration})
|
60 |
+
durations.append(duration)
|
61 |
+
vocab_set.update(list(text))
|
62 |
+
return sub_result, durations, vocab_set, bad_case_zh, bad_case_en
|
63 |
+
|
64 |
+
|
65 |
+
def main():
|
66 |
+
assert tokenizer in ["pinyin", "char"]
|
67 |
+
result = []
|
68 |
+
duration_list = []
|
69 |
+
text_vocab_set = set()
|
70 |
+
total_bad_case_zh = 0
|
71 |
+
total_bad_case_en = 0
|
72 |
+
|
73 |
+
# process raw data
|
74 |
+
executor = ProcessPoolExecutor(max_workers=max_workers)
|
75 |
+
futures = []
|
76 |
+
for lang in langs:
|
77 |
+
dataset_path = Path(os.path.join(dataset_dir, lang))
|
78 |
+
[
|
79 |
+
futures.append(executor.submit(deal_with_audio_dir, audio_dir))
|
80 |
+
for audio_dir in dataset_path.iterdir()
|
81 |
+
if audio_dir.is_dir()
|
82 |
+
]
|
83 |
+
for futures in tqdm(futures, total=len(futures)):
|
84 |
+
sub_result, durations, vocab_set, bad_case_zh, bad_case_en = futures.result()
|
85 |
+
result.extend(sub_result)
|
86 |
+
duration_list.extend(durations)
|
87 |
+
text_vocab_set.update(vocab_set)
|
88 |
+
total_bad_case_zh += bad_case_zh
|
89 |
+
total_bad_case_en += bad_case_en
|
90 |
+
executor.shutdown()
|
91 |
+
|
92 |
+
# save preprocessed dataset to disk
|
93 |
+
if not os.path.exists(f"data/{dataset_name}"):
|
94 |
+
os.makedirs(f"data/{dataset_name}")
|
95 |
+
print(f"\nSaving to data/{dataset_name} ...")
|
96 |
+
# dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom
|
97 |
+
# dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
|
98 |
+
with ArrowWriter(path=f"data/{dataset_name}/raw.arrow") as writer:
|
99 |
+
for line in tqdm(result, desc=f"Writing to raw.arrow ..."):
|
100 |
+
writer.write(line)
|
101 |
+
|
102 |
+
# dup a json separately saving duration in case for DynamicBatchSampler ease
|
103 |
+
with open(f"data/{dataset_name}/duration.json", 'w', encoding='utf-8') as f:
|
104 |
+
json.dump({"duration": duration_list}, f, ensure_ascii=False)
|
105 |
+
|
106 |
+
# vocab map, i.e. tokenizer
|
107 |
+
# add alphabets and symbols (optional, if plan to ft on de/fr etc.)
|
108 |
+
# if tokenizer == "pinyin":
|
109 |
+
# text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
|
110 |
+
with open(f"data/{dataset_name}/vocab.txt", "w") as f:
|
111 |
+
for vocab in sorted(text_vocab_set):
|
112 |
+
f.write(vocab + "\n")
|
113 |
+
|
114 |
+
print(f"\nFor {dataset_name}, sample count: {len(result)}")
|
115 |
+
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
|
116 |
+
print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
|
117 |
+
if "ZH" in langs: print(f"Bad zh transcription case: {total_bad_case_zh}")
|
118 |
+
if "EN" in langs: print(f"Bad en transcription case: {total_bad_case_en}\n")
|
119 |
+
|
120 |
+
|
121 |
+
if __name__ == "__main__":
|
122 |
+
|
123 |
+
max_workers = 32
|
124 |
+
|
125 |
+
tokenizer = "pinyin" # "pinyin" | "char"
|
126 |
+
polyphone = True
|
127 |
+
|
128 |
+
langs = ["ZH", "EN"]
|
129 |
+
dataset_dir = "<SOME_PATH>/Emilia_Dataset/raw"
|
130 |
+
dataset_name = f"Emilia_{'_'.join(langs)}_{tokenizer}"
|
131 |
+
print(f"\nPrepare for {dataset_name}\n")
|
132 |
+
|
133 |
+
main()
|
134 |
+
|
135 |
+
# Emilia ZH & EN
|
136 |
+
# samples count 37837916 (after removal)
|
137 |
+
# pinyin vocab size 2543 (polyphone)
|
138 |
+
# total duration 95281.87 (hours)
|
139 |
+
# bad zh asr cnt 230435 (samples)
|
140 |
+
# bad eh asr cnt 37217 (samples)
|
141 |
+
|
142 |
+
# vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
|
143 |
+
# please be careful if using pretrained model, make sure the vocab.txt is same
|
scripts/prepare_wenetspeech4tts.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# generate audio text map for WenetSpeech4TTS
|
2 |
+
# evaluate for vocab size
|
3 |
+
|
4 |
+
import sys, os
|
5 |
+
sys.path.append(os.getcwd())
|
6 |
+
|
7 |
+
import json
|
8 |
+
from tqdm import tqdm
|
9 |
+
from concurrent.futures import ProcessPoolExecutor
|
10 |
+
|
11 |
+
import torchaudio
|
12 |
+
from datasets import Dataset
|
13 |
+
|
14 |
+
from model.utils import convert_char_to_pinyin
|
15 |
+
|
16 |
+
|
17 |
+
def deal_with_sub_path_files(dataset_path, sub_path):
|
18 |
+
print(f"Dealing with: {sub_path}")
|
19 |
+
|
20 |
+
text_dir = os.path.join(dataset_path, sub_path, "txts")
|
21 |
+
audio_dir = os.path.join(dataset_path, sub_path, "wavs")
|
22 |
+
text_files = os.listdir(text_dir)
|
23 |
+
|
24 |
+
audio_paths, texts, durations = [], [], []
|
25 |
+
for text_file in tqdm(text_files):
|
26 |
+
with open(os.path.join(text_dir, text_file), 'r', encoding='utf-8') as file:
|
27 |
+
first_line = file.readline().split("\t")
|
28 |
+
audio_nm = first_line[0]
|
29 |
+
audio_path = os.path.join(audio_dir, audio_nm + ".wav")
|
30 |
+
text = first_line[1].strip()
|
31 |
+
|
32 |
+
audio_paths.append(audio_path)
|
33 |
+
|
34 |
+
if tokenizer == "pinyin":
|
35 |
+
texts.extend(convert_char_to_pinyin([text], polyphone = polyphone))
|
36 |
+
elif tokenizer == "char":
|
37 |
+
texts.append(text)
|
38 |
+
|
39 |
+
audio, sample_rate = torchaudio.load(audio_path)
|
40 |
+
durations.append(audio.shape[-1] / sample_rate)
|
41 |
+
|
42 |
+
return audio_paths, texts, durations
|
43 |
+
|
44 |
+
|
45 |
+
def main():
|
46 |
+
assert tokenizer in ["pinyin", "char"]
|
47 |
+
|
48 |
+
audio_path_list, text_list, duration_list = [], [], []
|
49 |
+
|
50 |
+
executor = ProcessPoolExecutor(max_workers=max_workers)
|
51 |
+
futures = []
|
52 |
+
for dataset_path in dataset_paths:
|
53 |
+
sub_items = os.listdir(dataset_path)
|
54 |
+
sub_paths = [item for item in sub_items if os.path.isdir(os.path.join(dataset_path, item))]
|
55 |
+
for sub_path in sub_paths:
|
56 |
+
futures.append(executor.submit(deal_with_sub_path_files, dataset_path, sub_path))
|
57 |
+
for future in tqdm(futures, total=len(futures)):
|
58 |
+
audio_paths, texts, durations = future.result()
|
59 |
+
audio_path_list.extend(audio_paths)
|
60 |
+
text_list.extend(texts)
|
61 |
+
duration_list.extend(durations)
|
62 |
+
executor.shutdown()
|
63 |
+
|
64 |
+
if not os.path.exists("data"):
|
65 |
+
os.makedirs("data")
|
66 |
+
|
67 |
+
print(f"\nSaving to data/{dataset_name}_{tokenizer} ...")
|
68 |
+
dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})
|
69 |
+
dataset.save_to_disk(f"data/{dataset_name}_{tokenizer}/raw", max_shard_size="2GB") # arrow format
|
70 |
+
|
71 |
+
with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'w', encoding='utf-8') as f:
|
72 |
+
json.dump({"duration": duration_list}, f, ensure_ascii=False) # dup a json separately saving duration in case for DynamicBatchSampler ease
|
73 |
+
|
74 |
+
print("\nEvaluating vocab size (all characters and symbols / all phonemes) ...")
|
75 |
+
text_vocab_set = set()
|
76 |
+
for text in tqdm(text_list):
|
77 |
+
text_vocab_set.update(list(text))
|
78 |
+
|
79 |
+
# add alphabets and symbols (optional, if plan to ft on de/fr etc.)
|
80 |
+
if tokenizer == "pinyin":
|
81 |
+
text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
|
82 |
+
|
83 |
+
with open(f"data/{dataset_name}_{tokenizer}/vocab.txt", "w") as f:
|
84 |
+
for vocab in sorted(text_vocab_set):
|
85 |
+
f.write(vocab + "\n")
|
86 |
+
print(f"\nFor {dataset_name}, sample count: {len(text_list)}")
|
87 |
+
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}\n")
|
88 |
+
|
89 |
+
|
90 |
+
if __name__ == "__main__":
|
91 |
+
|
92 |
+
max_workers = 32
|
93 |
+
|
94 |
+
tokenizer = "pinyin" # "pinyin" | "char"
|
95 |
+
polyphone = True
|
96 |
+
dataset_choice = 1 # 1: Premium, 2: Standard, 3: Basic
|
97 |
+
|
98 |
+
dataset_name = ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice-1]
|
99 |
+
dataset_paths = [
|
100 |
+
"<SOME_PATH>/WenetSpeech4TTS/Basic",
|
101 |
+
"<SOME_PATH>/WenetSpeech4TTS/Standard",
|
102 |
+
"<SOME_PATH>/WenetSpeech4TTS/Premium",
|
103 |
+
][-dataset_choice:]
|
104 |
+
print(f"\nChoose Dataset: {dataset_name}\n")
|
105 |
+
|
106 |
+
main()
|
107 |
+
|
108 |
+
# Results (if adding alphabets with accents and symbols):
|
109 |
+
# WenetSpeech4TTS Basic Standard Premium
|
110 |
+
# samples count 3932473 1941220 407494
|
111 |
+
# pinyin vocab size 1349 1348 1344 (no polyphone)
|
112 |
+
# - - 1459 (polyphone)
|
113 |
+
# char vocab size 5264 5219 5042
|
114 |
+
|
115 |
+
# vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
|
116 |
+
# please be careful if using pretrained model, make sure the vocab.txt is same
|
test_infer_batch.py
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import random
|
4 |
+
from tqdm import tqdm
|
5 |
+
import argparse
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torchaudio
|
9 |
+
from accelerate import Accelerator
|
10 |
+
from einops import rearrange
|
11 |
+
from ema_pytorch import EMA
|
12 |
+
from vocos import Vocos
|
13 |
+
|
14 |
+
from model import CFM, UNetT, DiT
|
15 |
+
from model.utils import (
|
16 |
+
get_tokenizer,
|
17 |
+
get_seedtts_testset_metainfo,
|
18 |
+
get_librispeech_test_clean_metainfo,
|
19 |
+
get_inference_prompt,
|
20 |
+
)
|
21 |
+
|
22 |
+
accelerator = Accelerator()
|
23 |
+
device = f"cuda:{accelerator.process_index}"
|
24 |
+
|
25 |
+
|
26 |
+
# --------------------- Dataset Settings -------------------- #
|
27 |
+
|
28 |
+
target_sample_rate = 24000
|
29 |
+
n_mel_channels = 100
|
30 |
+
hop_length = 256
|
31 |
+
target_rms = 0.1
|
32 |
+
|
33 |
+
tokenizer = "pinyin"
|
34 |
+
|
35 |
+
|
36 |
+
# ---------------------- infer setting ---------------------- #
|
37 |
+
|
38 |
+
parser = argparse.ArgumentParser(description="batch inference")
|
39 |
+
|
40 |
+
parser.add_argument('-s', '--seed', default=None, type=int)
|
41 |
+
parser.add_argument('-d', '--dataset', default="Emilia_ZH_EN")
|
42 |
+
parser.add_argument('-n', '--expname', required=True)
|
43 |
+
parser.add_argument('-c', '--ckptstep', default=1200000, type=int)
|
44 |
+
|
45 |
+
parser.add_argument('-nfe', '--nfestep', default=32, type=int)
|
46 |
+
parser.add_argument('-o', '--odemethod', default="euler")
|
47 |
+
parser.add_argument('-ss', '--swaysampling', default=-1, type=float)
|
48 |
+
|
49 |
+
parser.add_argument('-t', '--testset', required=True)
|
50 |
+
|
51 |
+
args = parser.parse_args()
|
52 |
+
|
53 |
+
|
54 |
+
seed = args.seed
|
55 |
+
dataset_name = args.dataset
|
56 |
+
exp_name = args.expname
|
57 |
+
ckpt_step = args.ckptstep
|
58 |
+
checkpoint = torch.load(f"ckpts/{exp_name}/model_{ckpt_step}.pt", map_location=device)
|
59 |
+
|
60 |
+
nfe_step = args.nfestep
|
61 |
+
ode_method = args.odemethod
|
62 |
+
sway_sampling_coef = args.swaysampling
|
63 |
+
|
64 |
+
testset = args.testset
|
65 |
+
|
66 |
+
|
67 |
+
infer_batch_size = 1 # max frames. 1 for ddp single inference (recommended)
|
68 |
+
cfg_strength = 2.
|
69 |
+
speed = 1.
|
70 |
+
use_truth_duration = False
|
71 |
+
no_ref_audio = False
|
72 |
+
|
73 |
+
|
74 |
+
if exp_name == "F5TTS_Base":
|
75 |
+
model_cls = DiT
|
76 |
+
model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
|
77 |
+
|
78 |
+
elif exp_name == "E2TTS_Base":
|
79 |
+
model_cls = UNetT
|
80 |
+
model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
|
81 |
+
|
82 |
+
|
83 |
+
if testset == "ls_pc_test_clean":
|
84 |
+
metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
|
85 |
+
librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
|
86 |
+
metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
|
87 |
+
|
88 |
+
elif testset == "seedtts_test_zh":
|
89 |
+
metalst = "data/seedtts_testset/zh/meta.lst"
|
90 |
+
metainfo = get_seedtts_testset_metainfo(metalst)
|
91 |
+
|
92 |
+
elif testset == "seedtts_test_en":
|
93 |
+
metalst = "data/seedtts_testset/en/meta.lst"
|
94 |
+
metainfo = get_seedtts_testset_metainfo(metalst)
|
95 |
+
|
96 |
+
|
97 |
+
# path to save genereted wavs
|
98 |
+
if seed is None: seed = random.randint(-10000, 10000)
|
99 |
+
output_dir = f"results/{exp_name}_{ckpt_step}/{testset}/" \
|
100 |
+
f"seed{seed}_{ode_method}_nfe{nfe_step}" \
|
101 |
+
f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}" \
|
102 |
+
f"_cfg{cfg_strength}_speed{speed}" \
|
103 |
+
f"{'_gt-dur' if use_truth_duration else ''}" \
|
104 |
+
f"{'_no-ref-audio' if no_ref_audio else ''}"
|
105 |
+
|
106 |
+
|
107 |
+
# -------------------------------------------------#
|
108 |
+
|
109 |
+
use_ema = True
|
110 |
+
|
111 |
+
prompts_all = get_inference_prompt(
|
112 |
+
metainfo,
|
113 |
+
speed = speed,
|
114 |
+
tokenizer = tokenizer,
|
115 |
+
target_sample_rate = target_sample_rate,
|
116 |
+
n_mel_channels = n_mel_channels,
|
117 |
+
hop_length = hop_length,
|
118 |
+
target_rms = target_rms,
|
119 |
+
use_truth_duration = use_truth_duration,
|
120 |
+
infer_batch_size = infer_batch_size,
|
121 |
+
)
|
122 |
+
|
123 |
+
# Vocoder model
|
124 |
+
local = False
|
125 |
+
if local:
|
126 |
+
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
|
127 |
+
vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
|
128 |
+
state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
|
129 |
+
vocos.load_state_dict(state_dict)
|
130 |
+
vocos.eval()
|
131 |
+
else:
|
132 |
+
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
133 |
+
|
134 |
+
# Tokenizer
|
135 |
+
vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
|
136 |
+
|
137 |
+
# Model
|
138 |
+
model = CFM(
|
139 |
+
transformer = model_cls(
|
140 |
+
**model_cfg,
|
141 |
+
text_num_embeds = vocab_size,
|
142 |
+
mel_dim = n_mel_channels
|
143 |
+
),
|
144 |
+
mel_spec_kwargs = dict(
|
145 |
+
target_sample_rate = target_sample_rate,
|
146 |
+
n_mel_channels = n_mel_channels,
|
147 |
+
hop_length = hop_length,
|
148 |
+
),
|
149 |
+
odeint_kwargs = dict(
|
150 |
+
method = ode_method,
|
151 |
+
),
|
152 |
+
vocab_char_map = vocab_char_map,
|
153 |
+
).to(device)
|
154 |
+
|
155 |
+
if use_ema == True:
|
156 |
+
ema_model = EMA(model, include_online_model = False).to(device)
|
157 |
+
ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
|
158 |
+
ema_model.copy_params_from_ema_to_model()
|
159 |
+
else:
|
160 |
+
model.load_state_dict(checkpoint['model_state_dict'])
|
161 |
+
|
162 |
+
if not os.path.exists(output_dir) and accelerator.is_main_process:
|
163 |
+
os.makedirs(output_dir)
|
164 |
+
|
165 |
+
# start batch inference
|
166 |
+
accelerator.wait_for_everyone()
|
167 |
+
start = time.time()
|
168 |
+
|
169 |
+
with accelerator.split_between_processes(prompts_all) as prompts:
|
170 |
+
|
171 |
+
for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
|
172 |
+
utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
|
173 |
+
ref_mels = ref_mels.to(device)
|
174 |
+
ref_mel_lens = torch.tensor(ref_mel_lens, dtype = torch.long).to(device)
|
175 |
+
total_mel_lens = torch.tensor(total_mel_lens, dtype = torch.long).to(device)
|
176 |
+
|
177 |
+
# Inference
|
178 |
+
with torch.inference_mode():
|
179 |
+
generated, _ = model.sample(
|
180 |
+
cond = ref_mels,
|
181 |
+
text = final_text_list,
|
182 |
+
duration = total_mel_lens,
|
183 |
+
lens = ref_mel_lens,
|
184 |
+
steps = nfe_step,
|
185 |
+
cfg_strength = cfg_strength,
|
186 |
+
sway_sampling_coef = sway_sampling_coef,
|
187 |
+
no_ref_audio = no_ref_audio,
|
188 |
+
seed = seed,
|
189 |
+
)
|
190 |
+
# Final result
|
191 |
+
for i, gen in enumerate(generated):
|
192 |
+
gen = gen[ref_mel_lens[i]:total_mel_lens[i], :].unsqueeze(0)
|
193 |
+
gen_mel_spec = rearrange(gen, '1 n d -> 1 d n')
|
194 |
+
generated_wave = vocos.decode(gen_mel_spec.cpu())
|
195 |
+
if ref_rms_list[i] < target_rms:
|
196 |
+
generated_wave = generated_wave * ref_rms_list[i] / target_rms
|
197 |
+
torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)
|
198 |
+
|
199 |
+
accelerator.wait_for_everyone()
|
200 |
+
if accelerator.is_main_process:
|
201 |
+
timediff = time.time() - start
|
202 |
+
print(f"Done batch inference in {timediff / 60 :.2f} minutes.")
|
test_infer_batch.sh
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# e.g. F5-TTS, 16 NFE
|
4 |
+
accelerate launch test_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
|
5 |
+
accelerate launch test_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
|
6 |
+
accelerate launch test_infer_batch.py -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16
|
7 |
+
|
8 |
+
# e.g. Vanilla E2 TTS, 32 NFE
|
9 |
+
accelerate launch test_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
|
10 |
+
accelerate launch test_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
|
11 |
+
accelerate launch test_infer_batch.py -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0
|
12 |
+
|
13 |
+
# etc.
|
test_infer_single.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torchaudio
|
6 |
+
from einops import rearrange
|
7 |
+
from ema_pytorch import EMA
|
8 |
+
from vocos import Vocos
|
9 |
+
|
10 |
+
from model import CFM, UNetT, DiT, MMDiT
|
11 |
+
from model.utils import (
|
12 |
+
get_tokenizer,
|
13 |
+
convert_char_to_pinyin,
|
14 |
+
save_spectrogram,
|
15 |
+
)
|
16 |
+
|
17 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
18 |
+
|
19 |
+
|
20 |
+
# --------------------- Dataset Settings -------------------- #
|
21 |
+
|
22 |
+
target_sample_rate = 24000
|
23 |
+
n_mel_channels = 100
|
24 |
+
hop_length = 256
|
25 |
+
target_rms = 0.1
|
26 |
+
|
27 |
+
tokenizer = "pinyin"
|
28 |
+
dataset_name = "Emilia_ZH_EN"
|
29 |
+
|
30 |
+
|
31 |
+
# ---------------------- infer setting ---------------------- #
|
32 |
+
|
33 |
+
seed = None # int | None
|
34 |
+
|
35 |
+
exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
|
36 |
+
ckpt_step = 1200000
|
37 |
+
|
38 |
+
nfe_step = 32 # 16, 32
|
39 |
+
cfg_strength = 2.
|
40 |
+
ode_method = 'euler' # euler | midpoint
|
41 |
+
sway_sampling_coef = -1.
|
42 |
+
speed = 1.
|
43 |
+
fix_duration = 27 # None (will linear estimate. if code-switched, consider fix) | float (total in seconds, include ref audio)
|
44 |
+
|
45 |
+
if exp_name == "F5TTS_Base":
|
46 |
+
model_cls = DiT
|
47 |
+
model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
|
48 |
+
|
49 |
+
elif exp_name == "E2TTS_Base":
|
50 |
+
model_cls = UNetT
|
51 |
+
model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
|
52 |
+
|
53 |
+
checkpoint = torch.load(f"ckpts/{exp_name}/model_{ckpt_step}.pt", map_location=device)
|
54 |
+
output_dir = "tests"
|
55 |
+
|
56 |
+
ref_audio = "tests/ref_audio/test_en_1_ref_short.wav"
|
57 |
+
ref_text = "Some call me nature, others call me mother nature."
|
58 |
+
gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
|
59 |
+
|
60 |
+
# ref_audio = "tests/ref_audio/test_zh_1_ref_short.wav"
|
61 |
+
# ref_text = "对,这就是我,万人敬仰的太乙真人。"
|
62 |
+
# gen_text = "突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道:\"我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?\""
|
63 |
+
|
64 |
+
|
65 |
+
# -------------------------------------------------#
|
66 |
+
|
67 |
+
use_ema = True
|
68 |
+
|
69 |
+
if not os.path.exists(output_dir):
|
70 |
+
os.makedirs(output_dir)
|
71 |
+
|
72 |
+
# Vocoder model
|
73 |
+
local = False
|
74 |
+
if local:
|
75 |
+
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
|
76 |
+
vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
|
77 |
+
state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
|
78 |
+
vocos.load_state_dict(state_dict)
|
79 |
+
vocos.eval()
|
80 |
+
else:
|
81 |
+
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
82 |
+
|
83 |
+
# Tokenizer
|
84 |
+
vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
|
85 |
+
|
86 |
+
# Model
|
87 |
+
model = CFM(
|
88 |
+
transformer = model_cls(
|
89 |
+
**model_cfg,
|
90 |
+
text_num_embeds = vocab_size,
|
91 |
+
mel_dim = n_mel_channels
|
92 |
+
),
|
93 |
+
mel_spec_kwargs = dict(
|
94 |
+
target_sample_rate = target_sample_rate,
|
95 |
+
n_mel_channels = n_mel_channels,
|
96 |
+
hop_length = hop_length,
|
97 |
+
),
|
98 |
+
odeint_kwargs = dict(
|
99 |
+
method = ode_method,
|
100 |
+
),
|
101 |
+
vocab_char_map = vocab_char_map,
|
102 |
+
).to(device)
|
103 |
+
|
104 |
+
if use_ema == True:
|
105 |
+
ema_model = EMA(model, include_online_model = False).to(device)
|
106 |
+
ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
|
107 |
+
ema_model.copy_params_from_ema_to_model()
|
108 |
+
else:
|
109 |
+
model.load_state_dict(checkpoint['model_state_dict'])
|
110 |
+
|
111 |
+
# Audio
|
112 |
+
audio, sr = torchaudio.load(ref_audio)
|
113 |
+
rms = torch.sqrt(torch.mean(torch.square(audio)))
|
114 |
+
if rms < target_rms:
|
115 |
+
audio = audio * target_rms / rms
|
116 |
+
if sr != target_sample_rate:
|
117 |
+
resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
|
118 |
+
audio = resampler(audio)
|
119 |
+
audio = audio.to(device)
|
120 |
+
|
121 |
+
# Text
|
122 |
+
text_list = [ref_text + gen_text]
|
123 |
+
if tokenizer == "pinyin":
|
124 |
+
final_text_list = convert_char_to_pinyin(text_list)
|
125 |
+
else:
|
126 |
+
final_text_list = [text_list]
|
127 |
+
print(f"text : {text_list}")
|
128 |
+
print(f"pinyin: {final_text_list}")
|
129 |
+
|
130 |
+
# Duration
|
131 |
+
ref_audio_len = audio.shape[-1] // hop_length
|
132 |
+
if fix_duration is not None:
|
133 |
+
duration = int(fix_duration * target_sample_rate / hop_length)
|
134 |
+
else: # simple linear scale calcul
|
135 |
+
zh_pause_punc = r"。,、;:?!"
|
136 |
+
ref_text_len = len(ref_text) + len(re.findall(zh_pause_punc, ref_text))
|
137 |
+
gen_text_len = len(gen_text) + len(re.findall(zh_pause_punc, gen_text))
|
138 |
+
duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
|
139 |
+
|
140 |
+
# Inference
|
141 |
+
with torch.inference_mode():
|
142 |
+
generated, trajectory = model.sample(
|
143 |
+
cond = audio,
|
144 |
+
text = final_text_list,
|
145 |
+
duration = duration,
|
146 |
+
steps = nfe_step,
|
147 |
+
cfg_strength = cfg_strength,
|
148 |
+
sway_sampling_coef = sway_sampling_coef,
|
149 |
+
seed = seed,
|
150 |
+
)
|
151 |
+
print(f"Generated mel: {generated.shape}")
|
152 |
+
|
153 |
+
# Final result
|
154 |
+
generated = generated[:, ref_audio_len:, :]
|
155 |
+
generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
|
156 |
+
generated_wave = vocos.decode(generated_mel_spec.cpu())
|
157 |
+
if rms < target_rms:
|
158 |
+
generated_wave = generated_wave * rms / target_rms
|
159 |
+
|
160 |
+
save_spectrogram(generated_mel_spec[0].cpu().numpy(), f"{output_dir}/test_single.png")
|
161 |
+
torchaudio.save(f"{output_dir}/test_single.wav", generated_wave, target_sample_rate)
|
162 |
+
print(f"Generated wav: {generated_wave.shape}")
|
test_train.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from model import CFM, UNetT, DiT, MMDiT, Trainer
|
2 |
+
from model.utils import get_tokenizer
|
3 |
+
from model.dataset import load_dataset
|
4 |
+
|
5 |
+
|
6 |
+
# -------------------------- Dataset Settings --------------------------- #
|
7 |
+
|
8 |
+
target_sample_rate = 24000
|
9 |
+
n_mel_channels = 100
|
10 |
+
hop_length = 256
|
11 |
+
|
12 |
+
tokenizer = "pinyin"
|
13 |
+
dataset_name = "Emilia_ZH_EN"
|
14 |
+
|
15 |
+
|
16 |
+
# -------------------------- Training Settings -------------------------- #
|
17 |
+
|
18 |
+
exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
|
19 |
+
|
20 |
+
learning_rate = 7.5e-5
|
21 |
+
|
22 |
+
batch_size_per_gpu = 38400 # 8 GPUs, 8 * 38400 = 307200
|
23 |
+
batch_size_type = "frame" # "frame" or "sample"
|
24 |
+
max_samples = 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
|
25 |
+
grad_accumulation_steps = 1 # note: updates = steps / grad_accumulation_steps
|
26 |
+
max_grad_norm = 1.
|
27 |
+
|
28 |
+
epochs = 11 # use linear decay, thus epochs control the slope
|
29 |
+
num_warmup_updates = 20000 # warmup steps
|
30 |
+
save_per_updates = 50000 # save checkpoint per steps
|
31 |
+
last_per_steps = 5000 # save last checkpoint per steps
|
32 |
+
|
33 |
+
# model params
|
34 |
+
if exp_name == "F5TTS_Base":
|
35 |
+
wandb_resume_id = None
|
36 |
+
model_cls = DiT
|
37 |
+
model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
|
38 |
+
elif exp_name == "E2TTS_Base":
|
39 |
+
wandb_resume_id = None
|
40 |
+
model_cls = UNetT
|
41 |
+
model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
|
42 |
+
|
43 |
+
|
44 |
+
# ----------------------------------------------------------------------- #
|
45 |
+
|
46 |
+
def main():
|
47 |
+
|
48 |
+
vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
|
49 |
+
|
50 |
+
mel_spec_kwargs = dict(
|
51 |
+
target_sample_rate = target_sample_rate,
|
52 |
+
n_mel_channels = n_mel_channels,
|
53 |
+
hop_length = hop_length,
|
54 |
+
)
|
55 |
+
|
56 |
+
e2tts = CFM(
|
57 |
+
transformer = model_cls(
|
58 |
+
**model_cfg,
|
59 |
+
text_num_embeds = vocab_size,
|
60 |
+
mel_dim = n_mel_channels
|
61 |
+
),
|
62 |
+
mel_spec_kwargs = mel_spec_kwargs,
|
63 |
+
vocab_char_map = vocab_char_map,
|
64 |
+
)
|
65 |
+
|
66 |
+
trainer = Trainer(
|
67 |
+
e2tts,
|
68 |
+
epochs,
|
69 |
+
learning_rate,
|
70 |
+
num_warmup_updates = num_warmup_updates,
|
71 |
+
save_per_updates = save_per_updates,
|
72 |
+
checkpoint_path = f'ckpts/{exp_name}',
|
73 |
+
batch_size = batch_size_per_gpu,
|
74 |
+
batch_size_type = batch_size_type,
|
75 |
+
max_samples = max_samples,
|
76 |
+
grad_accumulation_steps = grad_accumulation_steps,
|
77 |
+
max_grad_norm = max_grad_norm,
|
78 |
+
wandb_project = "CFM-TTS",
|
79 |
+
wandb_run_name = exp_name,
|
80 |
+
wandb_resume_id = wandb_resume_id,
|
81 |
+
last_per_steps = last_per_steps,
|
82 |
+
)
|
83 |
+
|
84 |
+
train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
|
85 |
+
trainer.train(train_dataset,
|
86 |
+
resumable_with_seed = 666 # seed for shuffling dataset
|
87 |
+
)
|
88 |
+
|
89 |
+
|
90 |
+
if __name__ == '__main__':
|
91 |
+
main()
|
tests/ref_audio/test_en_1_ref_short.wav
ADDED
Binary file (256 kB). View file
|
|
tests/ref_audio/test_zh_1_ref_short.wav
ADDED
Binary file (325 kB). View file
|
|