txya900619 committed on
Commit ec8f857 · 1 Parent(s): 4ab241b

feat: upload new app.py

Files changed (3)
  1. app.py +35 -16
  2. patch/e2_tts_pytorch.py +155 -0
  3. requirements.txt +2 -1
app.py CHANGED
@@ -1,15 +1,17 @@
 import os
 
 import gradio as gr
+import librosa
 import spaces
 import torch
-from e2_tts_pytorch import E2TTS, DurationPredictor
+from e2_tts_pytorch import DurationPredictor
 from huggingface_hub import snapshot_download
 from omegaconf import OmegaConf
 from tokenizers import Tokenizer
 from transformers import PreTrainedTokenizerFast
 
 from ipa.ipa import get_ipa, parse_ipa
+from patch.e2_tts_pytorch import E2TTSPatched as E2TTS
 
 
 def load_model(model_id):
@@ -67,24 +69,32 @@ models_config = OmegaConf.to_object(OmegaConf.load("configs/models.yaml"))
 
 
 @spaces.GPU
-def _do_tts(model_id, ipa, ref_wav, ref_transcript):
-    model = models_config[model_id]["model"].cuda()
-    generated = model.sample(
-        cond=torch.from_numpy(ref_wav).float().unsqueeze(0).cuda(),
-        text=[ref_transcript + ipa],
-        steps=32,
-        cfg_strength=1.0,
-    )[0]
-    return generated.cpu().numpy()
+def _do_tts(model_id, ipa, ref_wav, ref_transcript, speed):
+    with torch.inference_mode():
+        model = models_config[model_id]["model"].cuda()
+        ref_wav = librosa.load(ref_wav, sr=model.sampling_rate)[0]
+        print(ref_transcript + ipa)
+        text = model.tokenizer([ref_transcript + ipa]).to(model.device)
+
+        generated = model.sample(
+            cond=torch.from_numpy(ref_wav).float().unsqueeze(0).cuda(),
+            text=text,
+            steps=32,
+            cfg_strength=1.0,
+            speed=speed,
+        )[0]
+        return generated.cpu().numpy()
 
 
 def text_to_speech(
     model_id: str,
+    use_default_or_custom: str,
+    speaker_name: str,
+    dialect: str,
+    speed: float,
     text: str,
     ref_wav: str,
     ref_transcript: str,
-    dialect: str,
-    # speed: float,
 ):
     if len(text) == 0:
         raise gr.Error("請勿輸入空字串。")
@@ -96,13 +106,13 @@ def text_to_speech(
     parsed_ipa = parse_ipa(ipa)
     if dialect == "nansixian":
        dialect = "sixian"
-    models_config[model_id]["model"].tts_model.length_scale = speed
 
     wav = _do_tts(
         model_id,
         parsed_ipa,
         ref_wav,
         ref_transcript,
+        speed,
     )
 
     return (
@@ -180,12 +190,20 @@ with demo:
         ref_wav = gr.Audio(
             visible=False,
             type="filepath",
+            value=list(models_config[default_model_id]["speaker_mapping"].values())[0][
+                "ref_wav"
+            ],
             waveform_options=gr.WaveformOptions(
                 show_controls=False,
                 sample_rate=24000,
             ),
         )
-        ref_transcript = gr.Textbox(visible=False)
+        ref_transcript = gr.Textbox(
+            value=list(models_config[default_model_id]["speaker_mapping"].values())[0][
+                "ref_transcript"
+            ],
+            visible=False,
+        )
 
         speaker_wav = gr.Audio(
             label="客製化語音",
@@ -259,12 +277,13 @@ with demo:
         text_to_speech,
         inputs=[
             model_drop_down,
-            input_text,
             use_default_or_custom_radio,
-            speaker_wav,
             speaker_drop_down,
             dialect_radio,
             speed,
+            input_text,
+            ref_wav,
+            ref_transcript,
         ],
         outputs=[
             gr.Textbox(interactive=False, label="斷詞"),
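
For readability, the new inference path introduced in this commit can be read in one piece roughly as follows. This is a minimal sketch, not code from the commit: the helper name synthesize is hypothetical, and the model handle is assumed to be the one app.py builds from configs/models.yaml via load_model.

import librosa
import torch

def synthesize(model, ref_wav_path, ref_transcript, ipa, speed=1.0):
    with torch.inference_mode():
        # Load the reference audio at the model's native sampling rate
        # (this is why librosa was added to requirements.txt).
        ref_wav, _ = librosa.load(ref_wav_path, sr=model.sampling_rate)
        # Tokenize the reference transcript concatenated with the target IPA.
        text = model.tokenizer([ref_transcript + ipa]).to(model.device)
        # E2TTSPatched.sample accepts a `speed` factor (see patch/e2_tts_pytorch.py).
        generated = model.sample(
            cond=torch.from_numpy(ref_wav).float().unsqueeze(0).to(model.device),
            text=text,
            steps=32,
            cfg_strength=1.0,
            speed=speed,
        )[0]
        return generated.cpu().numpy()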
patch/e2_tts_pytorch.py ADDED
@@ -0,0 +1,155 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Callable
+
+import torch
+import torch.nn.functional as F
+import torchaudio
+from e2_tts_pytorch import E2TTS
+from e2_tts_pytorch.e2_tts import Float, Int, exists, lens_to_mask
+from einops import rearrange
+from torchdiffeq import odeint
+
+
+class E2TTSPatched(E2TTS):
+    @torch.no_grad()
+    def sample(
+        self,
+        cond: Float["b n d"] | Float["b nw"],
+        *,
+        text: Int["b nt"] | list[str] | None = None,
+        lens: Int["b"] | None = None,
+        duration: int | Int["b"] | None = None,
+        steps=32,
+        cfg_strength=1.0,  # they used a classifier free guidance strength of 1.
+        max_duration=4096,  # in case the duration predictor goes haywire
+        vocoder: Callable[[Float["b d n"]], list[Float["_"]]] | None = None,
+        return_raw_output: bool | None = None,
+        save_to_filename: str | None = None,
+        speed: float = 1.0,
+    ) -> (Float["b n d"], list[Float["_"]]):
+        self.eval()
+
+        # raw wave
+
+        if cond.ndim == 2:
+            cond = self.mel_spec(cond)
+            cond = rearrange(cond, "b d n -> b n d")
+            assert cond.shape[-1] == self.num_channels
+
+        batch, cond_seq_len, device = *cond.shape[:2], cond.device
+
+        if not exists(lens):
+            lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)
+
+        # text
+
+        if isinstance(text, list):
+            text = self.tokenizer(text).to(device)
+            assert text.shape[0] == batch
+
+        if exists(text):
+            text_lens = (text != -1).sum(dim=-1)
+            lens = torch.maximum(
+                text_lens, lens
+            )  # make sure lengths are at least those of the text characters
+
+        # duration
+
+        cond_mask = lens_to_mask(lens)
+
+        if exists(duration):
+            if isinstance(duration, int):
+                duration = torch.full(
+                    (batch,), duration, device=device, dtype=torch.long
+                )
+
+        elif exists(self.duration_predictor):
+            duration = (
+                self.duration_predictor(cond, text=text, lens=lens, return_loss=False)
+                * speed
+            ).long()
+
+        duration = torch.maximum(
+            lens + 1, duration
+        )  # just add one token so something is generated
+        duration = duration.clamp(max=max_duration)
+
+        assert duration.shape[0] == batch
+
+        max_duration = duration.amax()
+
+        cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
+        cond_mask = F.pad(
+            cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False
+        )
+        cond_mask = rearrange(cond_mask, "... -> ... 1")
+
+        mask = lens_to_mask(duration)
+
+        # neural ode
+
+        def fn(t, x):
+            # at each step, conditioning is fixed
+
+            step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
+
+            # predict flow
+
+            return self.cfg_transformer_with_pred_head(
+                x, step_cond, times=t, text=text, mask=mask, cfg_strength=cfg_strength
+            )
+
+        y0 = torch.randn_like(cond)
+        t = torch.linspace(0, 1, steps, device=self.device)
+
+        trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
+        sampled = trajectory[-1]
+
+        out = sampled
+
+        out = torch.where(cond_mask, cond, out)
+
+        # able to return raw untransformed output, if not using mel rep
+
+        if exists(return_raw_output) and return_raw_output:
+            return out
+
+        # take care of transforming mel to audio if `vocoder` is passed in, or if `use_vocos` is turned on
+
+        if exists(vocoder):
+            assert not exists(
+                self.vocos
+            ), "`use_vocos` should not be turned on if you are passing in a custom `vocoder` on sampling"
+            out = rearrange(out, "b n d -> b d n")
+            out = vocoder(out)
+
+        elif exists(self.vocos):
+            audio = []
+            for mel, one_mask in zip(out, mask):
+                one_out = mel[one_mask]
+
+                one_out = rearrange(one_out, "n d -> 1 d n")
+                one_audio = self.vocos.decode(one_out)
+                one_audio = rearrange(one_audio, "1 nw -> nw")
+                audio.append(one_audio)
+
+            out = audio
+
+        if exists(save_to_filename):
+            assert exists(vocoder) or exists(self.vocos)
+            assert exists(self.sampling_rate)
+
+            path = Path(save_to_filename)
+            parent_path = path.parents[0]
+            parent_path.mkdir(exist_ok=True, parents=True)
+
+            for ind, one_audio in enumerate(out):
+                one_audio = rearrange(one_audio, "nw -> 1 nw")
+                save_path = str(parent_path / f"{ind + 1}.{path.name}")
+                torchaudio.save(
+                    save_path, one_audio.detach().cpu(), sample_rate=self.sampling_rate
+                )
+
+        return out
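
The main functional change relative to the upstream E2TTS.sample appears to be the new speed argument: when no explicit duration is passed, the duration predictor's output is multiplied by speed before the usual clamping. A minimal sketch of just that scaling step, with purely illustrative tensor values (not taken from the repo):

import torch

lens = torch.tensor([120])         # frames in the conditioning (reference) mel
predicted = torch.tensor([300.0])  # hypothetical duration predictor output
speed = 1.5                        # >1 stretches the generated clip, <1 shortens it

duration = (predicted * speed).long()
duration = torch.maximum(lens + 1, duration)  # always generate at least one new frame
duration = duration.clamp(max=4096)           # max_duration guard from sample()
print(duration)                               # tensor([450])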
requirements.txt CHANGED
@@ -3,4 +3,5 @@ opencc
 omegaconf
 e2_tts_pytorch
 transformers
-matplotlib
+matplotlib
+librosa