Commit ec8f857
1 Parent(s): 4ab241b

feat: upload new app.py

- app.py +35 -16
- patch/e2_tts_pytorch.py +155 -0
- requirements.txt +2 -1
app.py
CHANGED

@@ -1,15 +1,17 @@
 import os
 
 import gradio as gr
+import librosa
 import spaces
 import torch
-from e2_tts_pytorch import
+from e2_tts_pytorch import DurationPredictor
 from huggingface_hub import snapshot_download
 from omegaconf import OmegaConf
 from tokenizers import Tokenizer
 from transformers import PreTrainedTokenizerFast
 
 from ipa.ipa import get_ipa, parse_ipa
+from patch.e2_tts_pytorch import E2TTSPatched as E2TTS
 
 
 def load_model(model_id):
@@ -67,24 +69,32 @@ models_config = OmegaConf.to_object(OmegaConf.load("configs/models.yaml"))
 
 
 @spaces.GPU
-def _do_tts(model_id, ipa, ref_wav, ref_transcript):
-
-
-
-
-
-
-
-
+def _do_tts(model_id, ipa, ref_wav, ref_transcript, speed):
+    with torch.inference_mode():
+        model = models_config[model_id]["model"].cuda()
+        ref_wav = librosa.load(ref_wav, sr=model.sampling_rate)[0]
+        print(ref_transcript + ipa)
+        text = model.tokenizer([ref_transcript + ipa]).to(model.device)
+
+        generated = model.sample(
+            cond=torch.from_numpy(ref_wav).float().unsqueeze(0).cuda(),
+            text=text,
+            steps=32,
+            cfg_strength=1.0,
+            speed=speed,
+        )[0]
+        return generated.cpu().numpy()
 
 
 def text_to_speech(
     model_id: str,
+    use_default_or_custom: str,
+    speaker_name: str,
+    dialect: str,
+    speed: float,
     text: str,
     ref_wav: str,
     ref_transcript: str,
-    dialect: str,
-    # speed: float,
 ):
     if len(text) == 0:
         raise gr.Error("請勿輸入空字串。")
@@ -96,13 +106,13 @@ def text_to_speech(
     parsed_ipa = parse_ipa(ipa)
     if dialect == "nansixian":
         dialect = "sixian"
-    models_config[model_id]["model"].tts_model.length_scale = speed
 
    wav = _do_tts(
         model_id,
         parsed_ipa,
         ref_wav,
         ref_transcript,
+        speed,
     )
 
     return (
@@ -180,12 +190,20 @@ with demo:
         ref_wav = gr.Audio(
             visible=False,
             type="filepath",
+            value=list(models_config[default_model_id]["speaker_mapping"].values())[0][
+                "ref_wav"
+            ],
             waveform_options=gr.WaveformOptions(
                 show_controls=False,
                 sample_rate=24000,
             ),
         )
-        ref_transcript = gr.Textbox(
+        ref_transcript = gr.Textbox(
+            value=list(models_config[default_model_id]["speaker_mapping"].values())[0][
+                "ref_transcript"
+            ],
+            visible=False,
+        )
 
         speaker_wav = gr.Audio(
             label="客製化語音",
@@ -259,12 +277,13 @@ with demo:
             text_to_speech,
             inputs=[
                 model_drop_down,
-                input_text,
                 use_default_or_custom_radio,
-                speaker_wav,
                 speaker_drop_down,
                 dialect_radio,
                 speed,
+                input_text,
+                ref_wav,
+                ref_transcript,
             ],
             outputs=[
                 gr.Textbox(interactive=False, label="斷詞"),
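A note on the last hunk: Gradio passes `inputs` to an event handler positionally, so the reordered component list has to line up with the new `text_to_speech` signature. The mapping below is inferred from this diff, not stated anywhere in the commit:

# Positional correspondence between the click handler's inputs
# and the new text_to_speech parameters, as implied by this diff:
#
#   model_drop_down             -> model_id
#   use_default_or_custom_radio -> use_default_or_custom
#   speaker_drop_down           -> speaker_name
#   dialect_radio               -> dialect
#   speed                       -> speed
#   input_text                  -> text
#   ref_wav                     -> ref_wav
#   ref_transcript              -> ref_transcript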
patch/e2_tts_pytorch.py
ADDED

@@ -0,0 +1,155 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Callable
+
+import torch
+import torch.nn.functional as F
+import torchaudio
+from e2_tts_pytorch import E2TTS
+from e2_tts_pytorch.e2_tts import Float, Int, exists, lens_to_mask
+from einops import rearrange
+from torchdiffeq import odeint
+
+
+class E2TTSPatched(E2TTS):
+    @torch.no_grad()
+    def sample(
+        self,
+        cond: Float["b n d"] | Float["b nw"],
+        *,
+        text: Int["b nt"] | list[str] | None = None,
+        lens: Int["b"] | None = None,
+        duration: int | Int["b"] | None = None,
+        steps=32,
+        cfg_strength=1.0,  # they used a classifier free guidance strength of 1.
+        max_duration=4096,  # in case the duration predictor goes haywire
+        vocoder: Callable[[Float["b d n"]], list[Float["_"]]] | None = None,
+        return_raw_output: bool | None = None,
+        save_to_filename: str | None = None,
+        speed: float = 1.0,
+    ) -> (Float["b n d"], list[Float["_"]]):
+        self.eval()
+
+        # raw wave
+
+        if cond.ndim == 2:
+            cond = self.mel_spec(cond)
+            cond = rearrange(cond, "b d n -> b n d")
+            assert cond.shape[-1] == self.num_channels
+
+        batch, cond_seq_len, device = *cond.shape[:2], cond.device
+
+        if not exists(lens):
+            lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)
+
+        # text
+
+        if isinstance(text, list):
+            text = self.tokenizer(text).to(device)
+            assert text.shape[0] == batch
+
+        if exists(text):
+            text_lens = (text != -1).sum(dim=-1)
+            lens = torch.maximum(
+                text_lens, lens
+            )  # make sure lengths are at least those of the text characters
+
+        # duration
+
+        cond_mask = lens_to_mask(lens)
+
+        if exists(duration):
+            if isinstance(duration, int):
+                duration = torch.full(
+                    (batch,), duration, device=device, dtype=torch.long
+                )
+
+        elif exists(self.duration_predictor):
+            duration = (
+                self.duration_predictor(cond, text=text, lens=lens, return_loss=False)
+                * speed
+            ).long()
+
+        duration = torch.maximum(
+            lens + 1, duration
+        )  # just add one token so something is generated
+        duration = duration.clamp(max=max_duration)
+
+        assert duration.shape[0] == batch
+
+        max_duration = duration.amax()
+
+        cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
+        cond_mask = F.pad(
+            cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False
+        )
+        cond_mask = rearrange(cond_mask, "... -> ... 1")
+
+        mask = lens_to_mask(duration)
+
+        # neural ode
+
+        def fn(t, x):
+            # at each step, conditioning is fixed
+
+            step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
+
+            # predict flow
+
+            return self.cfg_transformer_with_pred_head(
+                x, step_cond, times=t, text=text, mask=mask, cfg_strength=cfg_strength
+            )
+
+        y0 = torch.randn_like(cond)
+        t = torch.linspace(0, 1, steps, device=self.device)
+
+        trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
+        sampled = trajectory[-1]
+
+        out = sampled
+
+        out = torch.where(cond_mask, cond, out)
+
+        # able to return raw untransformed output, if not using mel rep
+
+        if exists(return_raw_output) and return_raw_output:
+            return out
+
+        # take care of transforming mel to audio if `vocoder` is passed in, or if `use_vocos` is turned on
+
+        if exists(vocoder):
+            assert not exists(
+                self.vocos
+            ), "`use_vocos` should not be turned on if you are passing in a custom `vocoder` on sampling"
+            out = rearrange(out, "b n d -> b d n")
+            out = vocoder(out)
+
+        elif exists(self.vocos):
+            audio = []
+            for mel, one_mask in zip(out, mask):
+                one_out = mel[one_mask]
+
+                one_out = rearrange(one_out, "n d -> 1 d n")
+                one_audio = self.vocos.decode(one_out)
+                one_audio = rearrange(one_audio, "1 nw -> nw")
+                audio.append(one_audio)
+
+            out = audio
+
+        if exists(save_to_filename):
+            assert exists(vocoder) or exists(self.vocos)
+            assert exists(self.sampling_rate)
+
+            path = Path(save_to_filename)
+            parent_path = path.parents[0]
+            parent_path.mkdir(exist_ok=True, parents=True)
+
+            for ind, one_audio in enumerate(out):
+                one_audio = rearrange(one_audio, "nw -> 1 nw")
+                save_path = str(parent_path / f"{ind + 1}.{path.name}")
+                torchaudio.save(
+                    save_path, one_audio.detach().cpu(), sample_rate=self.sampling_rate
+                )
+
+        return out
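Functionally, this override appears to track upstream `E2TTS.sample` with one addition: the `speed` keyword, which scales the duration predictor's output before rounding, so values above 1.0 stretch the generated clip and values below 1.0 shorten it. A minimal usage sketch under that reading; `model` is assumed to be a loaded `E2TTSPatched` with a duration predictor and vocos vocoder attached, and `render_at_speeds` is a hypothetical helper, not part of this commit:

import torch


def render_at_speeds(model, ref_wav, text):
    # Render the same utterance at several rates to exercise the
    # patched `speed` argument; `ref_wav` is a (batch, num_samples)
    # raw waveform tensor on the model's device.
    outputs = {}
    for speed in (0.8, 1.0, 1.25):
        with torch.inference_mode():
            # sample() multiplies the predicted duration by `speed`,
            # so 1.25 yields a longer rendition than 0.8
            outputs[speed] = model.sample(
                cond=ref_wav,
                text=[text],
                steps=32,
                cfg_strength=1.0,
                speed=speed,
            )[0]
    return outputs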
requirements.txt
CHANGED

@@ -3,4 +3,5 @@ opencc
 omegaconf
 e2_tts_pytorch
 transformers
-matplotlib
+matplotlib
+librosa