mrfakename committed on
Commit
3e38683
1 Parent(s): bd9a76a

Sync from GitHub repo


This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space through the GitHub repo.

.github/workflows/sync-hf.yaml ADDED
@@ -0,0 +1,18 @@
+name: Sync to HF Space
+
+on:
+  push:
+    branches:
+      - main
+
+jobs:
+  trigger_curl:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Send cURL POST request
+        run: |
+          curl -X POST https://mrfakename-sync-f5.hf.space/gradio_api/call/refresh \
+            -s \
+            -H "Content-Type: application/json" \
+            -d "{\"data\": [\"${{ secrets.REFRESH_PASSWORD }}\"]}"
Dockerfile ADDED
@@ -0,0 +1,25 @@
+FROM pytorch/pytorch:2.4.0-cuda12.4-cudnn9-devel
+
+USER root
+
+ARG DEBIAN_FRONTEND=noninteractive
+
+LABEL github_repo="https://github.com/SWivid/F5-TTS"
+
+RUN set -x \
+    && apt-get update \
+    && apt-get -y install wget curl man git less openssl libssl-dev unzip unar build-essential aria2 tmux vim \
+    && apt-get install -y openssh-server sox libsox-fmt-all libsox-fmt-mp3 libsndfile1-dev ffmpeg \
+    && rm -rf /var/lib/apt/lists/* \
+    && apt-get clean
+
+WORKDIR /workspace
+
+RUN git clone https://github.com/SWivid/F5-TTS.git \
+    && cd F5-TTS \
+    && pip install --no-cache-dir -r requirements.txt \
+    && pip install --no-cache-dir -r requirements_eval.txt
+
+ENV SHELL=/bin/bash
+
+WORKDIR /workspace/F5-TTS
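One possible way to build and enter the image defined above, driven from Python for consistency with the rest of the repo. This is a sketch only: the `f5-tts` tag is arbitrary and `--gpus all` assumes the NVIDIA container runtime is installed; neither is specified by the Dockerfile itself.

```python
import subprocess

# Build the image from the Dockerfile in the current directory (tag is an assumption).
subprocess.run(["docker", "build", "-t", "f5-tts", "."], check=True)

# Open an interactive shell in the container with GPU access (assumes nvidia-container-toolkit).
subprocess.run(["docker", "run", "--gpus", "all", "-it", "f5-tts", "bash"], check=True)
```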
README.md CHANGED
@@ -7,7 +7,7 @@ sdk: gradio
 app_file: app.py
 pinned: true
 short_description: 'F5-TTS & E2-TTS: Zero-Shot Voice Cloning (Unofficial Demo)'
-sdk_version: 5.1.0
+sdk_version: 4.44.1
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
README_REPO.md CHANGED
@@ -2,15 +2,20 @@
 
 [![python](https://img.shields.io/badge/Python-3.10-brightgreen)](https://github.com/SWivid/F5-TTS)
 [![arXiv](https://img.shields.io/badge/arXiv-2410.06885-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2410.06885)
-[![demo](https://img.shields.io/badge/GitHub-Demo%20page-blue.svg)](https://swivid.github.io/F5-TTS/)
-[![space](https://img.shields.io/badge/🤗-Space%20demo-yellow)](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
+[![demo](https://img.shields.io/badge/GitHub-Demo%20page-orange.svg)](https://swivid.github.io/F5-TTS/)
+[![hfspace](https://img.shields.io/badge/🤗-Space%20demo-yellow)](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
+[![msspace](https://img.shields.io/badge/🤖-Space%20demo-blue)](https://modelscope.cn/studios/modelscope/E2-F5-TTS)
+[![lab](https://img.shields.io/badge/X--LANCE-Lab-grey?labelColor=lightgrey)](https://x-lance.sjtu.edu.cn/)
+<img src="https://github.com/user-attachments/assets/12d7749c-071a-427c-81bf-b87b91def670" alt="Watermark" style="width: 40px; height: auto">
 
 **F5-TTS**: Diffusion Transformer with ConvNeXt V2, faster trained and inference.
 
-**E2 TTS**: Flat-UNet Transformer, closest reproduction.
+**E2 TTS**: Flat-UNet Transformer, closest reproduction from [paper](https://arxiv.org/abs/2406.18009).
 
 **Sway Sampling**: Inference-time flow step sampling strategy, greatly improves performance
 
+### Thanks to all the contributors !
+
 ## Installation
 
 Clone the repository:
@@ -48,7 +53,7 @@ python scripts/prepare_emilia.py
 python scripts/prepare_wenetspeech4tts.py
 ```
 
-## Training
+## Training & Finetuning
 
 Once your datasets are prepared, you can start the training process.
 
@@ -60,9 +65,11 @@ accelerate launch train.py
 ```
 An initial guidance on Finetuning [#57](https://github.com/SWivid/F5-TTS/discussions/57).
 
+Gradio UI finetuning with `finetune_gradio.py` see [#143](https://github.com/SWivid/F5-TTS/discussions/143).
+
 ## Inference
 
-To run inference with pretrained models, download the checkpoints from [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS), or automatically downloaded with `inference-cli` and `gradio_app`.
+The pretrained model checkpoints can be reached at [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS) and [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), or automatically downloaded with `inference-cli` and `gradio_app`.
 
 Currently support 30s for a single generation, which is the **TOTAL** length of prompt audio and the generated. Batch inference with chunks is supported by `inference-cli` and `gradio_app`.
 - To avoid possible inference failures, make sure you have seen through the following instructions.
@@ -86,6 +93,9 @@ python inference-cli.py \
 --ref_audio "tests/ref_audio/test_zh_1_ref_short.wav" \
 --ref_text "对,这就是我,万人敬仰的太乙真人。" \
 --gen_text "突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道,我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?"
+
+# Multi voice
+python inference-cli.py -c samples/story.toml
 ```
 
 ### Gradio App
@@ -148,6 +158,12 @@ bash scripts/eval_infer_batch.sh
 
 ### Objective Evaluation
 
+Install packages for evaluation:
+
+```bash
+pip install -r requirements_eval.txt
+```
+
 **Some Notes**
 
 For faster-whisper with CUDA 11:
@@ -178,11 +194,13 @@ python scripts/eval_librispeech_test_clean.py
 - [lucidrains](https://github.com/lucidrains) initial CFM structure with also [bfs18](https://github.com/bfs18) for discussion
 - [SD3](https://arxiv.org/abs/2403.03206) & [Hugging Face diffusers](https://github.com/huggingface/diffusers) DiT and MMDiT code structure
 - [torchdiffeq](https://github.com/rtqichen/torchdiffeq) as ODE solver, [Vocos](https://huggingface.co/charactr/vocos-mel-24khz) as vocoder
-- [mrfakename](https://x.com/realmrfakename) huggingface space demo ~
 - [FunASR](https://github.com/modelscope/FunASR), [faster-whisper](https://github.com/SYSTRAN/faster-whisper), [UniSpeech](https://github.com/microsoft/UniSpeech) for evaluation tools
 - [ctc-forced-aligner](https://github.com/MahmoudAshraf97/ctc-forced-aligner) for speech edit test
+- [mrfakename](https://x.com/realmrfakename) huggingface space demo ~
+- [f5-tts-mlx](https://github.com/lucasnewman/f5-tts-mlx/tree/main) Implementation of F5-TTS, with the MLX framework.
 
 ## Citation
+If our work and codebase is useful for you, please cite as:
 ```
 @article{chen-etal-2024-f5tts,
     title={F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching},
@@ -193,4 +211,4 @@ python scripts/eval_librispeech_test_clean.py
 ```
 ## License
 
-Our code is released under MIT License.
+Our code is released under MIT License. The pre-trained models are licensed under the CC-BY-NC license due to the training data Emilia, which is an in-the-wild dataset. Sorry for any inconvenience this may cause.
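As a convenience note on the checkpoint links above: the base checkpoint can also be fetched ahead of time from Python. A minimal sketch assuming the `huggingface_hub` package; the filename follows the `hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt` path used in `finetune-cli.py` below.

```python
from huggingface_hub import hf_hub_download

# Pre-download the F5-TTS base checkpoint into the local Hugging Face cache.
ckpt_path = hf_hub_download(repo_id="SWivid/F5-TTS", filename="F5TTS_Base/model_1200000.pt")
print(ckpt_path)
```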
app.py CHANGED
@@ -1,4 +1,3 @@
-import os
 import re
 import torch
 import torchaudio
@@ -17,7 +16,6 @@ from model.utils import (
     save_spectrogram,
 )
 from transformers import pipeline
-import librosa
 import click
 import soundfile as sf
 
@@ -33,19 +31,6 @@ def gpu_decorator(func):
     else:
         return func
 
-
-SPLIT_WORDS = [
-    "but", "however", "nevertheless", "yet", "still",
-    "therefore", "thus", "hence", "consequently",
-    "moreover", "furthermore", "additionally",
-    "meanwhile", "alternatively", "otherwise",
-    "namely", "specifically", "for example", "such as",
-    "in fact", "indeed", "notably",
-    "in contrast", "on the other hand", "conversely",
-    "in conclusion", "to summarize", "finally"
-]
-
 device = (
     "cuda"
     if torch.cuda.is_available()
@@ -73,7 +58,6 @@ cfg_strength = 2.0
 ode_method = "euler"
 sway_sampling_coef = -1.0
 speed = 1.0
-# fix_duration = 27 # None or float (duration in seconds)
 fix_duration = None
 
 
@@ -114,104 +98,37 @@ E2TTS_ema_model = load_model(
     "E2-TTS", "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
 )
 
-def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
-    if len(text.encode('utf-8')) <= max_chars:
-        return [text]
-    if text[-1] not in ['。', '.', '!', '!', '?', '?']:
-        text += '.'
-
-    sentences = re.split('([。.!?!?])', text)
-    sentences = [''.join(i) for i in zip(sentences[0::2], sentences[1::2])]
-
-    batches = []
-    current_batch = ""
-
-    def split_by_words(text):
-        words = text.split()
-        current_word_part = ""
-        word_batches = []
-        for word in words:
-            if len(current_word_part.encode('utf-8')) + len(word.encode('utf-8')) + 1 <= max_chars:
-                current_word_part += word + ' '
-            else:
-                if current_word_part:
-                    # Try to find a suitable split word
-                    for split_word in split_words:
-                        split_index = current_word_part.rfind(' ' + split_word + ' ')
-                        if split_index != -1:
-                            word_batches.append(current_word_part[:split_index].strip())
-                            current_word_part = current_word_part[split_index:].strip() + ' '
-                            break
-                    else:
-                        # If no suitable split word found, just append the current part
-                        word_batches.append(current_word_part.strip())
-                        current_word_part = ""
-                current_word_part += word + ' '
-        if current_word_part:
-            word_batches.append(current_word_part.strip())
-        return word_batches
+def chunk_text(text, max_chars=135):
+    """
+    Splits the input text into chunks, each with a maximum number of characters.
+
+    Args:
+        text (str): The text to be split.
+        max_chars (int): The maximum number of characters per chunk.
+
+    Returns:
+        List[str]: A list of text chunks.
+    """
+    chunks = []
+    current_chunk = ""
+    # Split the text into sentences based on punctuation followed by whitespace
+    sentences = re.split(r'(?<=[;:,.!?])\s+|(?<=[;:,。!?])', text)
 
     for sentence in sentences:
-        if len(current_batch.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
-            current_batch += sentence
+        if len(current_chunk.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
+            current_chunk += sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
         else:
-            # If adding this sentence would exceed the limit
-            if current_batch:
-                batches.append(current_batch)
-                current_batch = ""
-
-            # If the sentence itself is longer than max_chars, split it
-            if len(sentence.encode('utf-8')) > max_chars:
-                # First, try to split by colon
-                colon_parts = sentence.split(':')
-                if len(colon_parts) > 1:
-                    for part in colon_parts:
-                        if len(part.encode('utf-8')) <= max_chars:
-                            batches.append(part)
-                        else:
-                            # If colon part is still too long, split by comma
-                            comma_parts = re.split('[,,]', part)
-                            if len(comma_parts) > 1:
-                                current_comma_part = ""
-                                for comma_part in comma_parts:
-                                    if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
-                                        current_comma_part += comma_part + ','
-                                    else:
-                                        if current_comma_part:
-                                            batches.append(current_comma_part.rstrip(','))
-                                        current_comma_part = comma_part + ','
-                                if current_comma_part:
-                                    batches.append(current_comma_part.rstrip(','))
-                            else:
-                                # If no comma, split by words
-                                batches.extend(split_by_words(part))
-                else:
-                    # If no colon, split by comma
-                    comma_parts = re.split('[,,]', sentence)
-                    if len(comma_parts) > 1:
-                        current_comma_part = ""
-                        for comma_part in comma_parts:
-                            if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
-                                current_comma_part += comma_part + ','
-                            else:
-                                if current_comma_part:
-                                    batches.append(current_comma_part.rstrip(','))
-                                current_comma_part = comma_part + ','
-                        if current_comma_part:
-                            batches.append(current_comma_part.rstrip(','))
-                    else:
-                        # If no comma, split by words
-                        batches.extend(split_by_words(sentence))
-            else:
-                current_batch = sentence
-
-    if current_batch:
-        batches.append(current_batch)
-
-    return batches
+            if current_chunk:
+                chunks.append(current_chunk.strip())
+            current_chunk = sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
 
-@spaces.GPU
-def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence, progress=gr.Progress()):
+    if current_chunk:
+        chunks.append(current_chunk.strip())
+
+    return chunks
+
+@gpu_decorator
+def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence, cross_fade_duration=0.15, progress=gr.Progress()):
     if exp_name == "F5-TTS":
        ema_model = F5TTS_ema_model
     elif exp_name == "E2-TTS":
@@ -269,8 +186,44 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence,
         generated_waves.append(generated_wave)
         spectrograms.append(generated_mel_spec[0].cpu().numpy())
 
-    # Combine all generated waves
-    final_wave = np.concatenate(generated_waves)
+    # Combine all generated waves with cross-fading
+    if cross_fade_duration <= 0:
+        # Simply concatenate
+        final_wave = np.concatenate(generated_waves)
+    else:
+        final_wave = generated_waves[0]
+        for i in range(1, len(generated_waves)):
+            prev_wave = final_wave
+            next_wave = generated_waves[i]
+
+            # Calculate cross-fade samples, ensuring it does not exceed wave lengths
+            cross_fade_samples = int(cross_fade_duration * target_sample_rate)
+            cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
+
+            if cross_fade_samples <= 0:
+                # No overlap possible, concatenate
+                final_wave = np.concatenate([prev_wave, next_wave])
+                continue
+
+            # Overlapping parts
+            prev_overlap = prev_wave[-cross_fade_samples:]
+            next_overlap = next_wave[:cross_fade_samples]
+
+            # Fade out and fade in
+            fade_out = np.linspace(1, 0, cross_fade_samples)
+            fade_in = np.linspace(0, 1, cross_fade_samples)
+
+            # Cross-faded overlap
+            cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
+
+            # Combine
+            new_wave = np.concatenate([
+                prev_wave[:-cross_fade_samples],
+                cross_faded_overlap,
+                next_wave[cross_fade_samples:]
+            ])
+
+            final_wave = new_wave
 
     # Remove silence
     if remove_silence:
@@ -295,12 +248,8 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence,
 
     return (target_sample_rate, final_wave), spectrogram_path
 
-@spaces.GPU
-def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_split_words=''):
-    if not custom_split_words.strip():
-        custom_words = [word.strip() for word in custom_split_words.split(',')]
-        global SPLIT_WORDS
-        SPLIT_WORDS = custom_words
+@gpu_decorator
+def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, cross_fade_duration=0.15):
 
     print(gen_text)
 
@@ -308,7 +257,9 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_s
     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
         aseg = AudioSegment.from_file(ref_audio_orig)
 
-        non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
+        non_silent_segs = silence.split_on_silence(
+            aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000
+        )
         non_silent_wave = AudioSegment.silent(duration=0)
         for non_silent_seg in non_silent_segs:
             non_silent_wave += non_silent_seg
@@ -334,18 +285,27 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_s
     else:
         gr.Info("Using custom reference text...")
 
-    # Split the input text into batches
+    # Add the functionality to ensure it ends with ". "
+    if not ref_text.endswith(". "):
+        if ref_text.endswith("."):
+            ref_text += " "
+        else:
+            ref_text += ". "
+
     audio, sr = torchaudio.load(ref_audio)
-    max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (30 - audio.shape[-1] / sr))
-    gen_text_batches = split_text_into_batches(gen_text, max_chars=max_chars)
+
+    # Use the new chunk_text function to split gen_text
+    max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
+    gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
     print('ref_text', ref_text)
-    for i, gen_text in enumerate(gen_text_batches):
-        print(f'gen_text {i}', gen_text)
+    for i, batch_text in enumerate(gen_text_batches):
+        print(f'gen_text {i}', batch_text)
 
     gr.Info(f"Generating audio using {exp_name} in {len(gen_text_batches)} batches")
-    return infer_batch((audio, sr), ref_text, gen_text_batches, exp_name, remove_silence)
+    return infer_batch((audio, sr), ref_text, gen_text_batches, exp_name, remove_silence, cross_fade_duration)
+
 
-@spaces.GPU
+@gpu_decorator
 def generate_podcast(script, speaker1_name, ref_audio1, ref_text1, speaker2_name, ref_audio2, ref_text2, exp_name, remove_silence):
     # Split the script into speaker blocks
     speaker_pattern = re.compile(f"^({re.escape(speaker1_name)}|{re.escape(speaker2_name)}):", re.MULTILINE)
@@ -429,6 +389,7 @@ with gr.Blocks() as app_credits:
 
 * [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
 * [RootingInLoad](https://github.com/RootingInLoad) for the podcast generation
+* [jpgallegoar](https://github.com/jpgallegoar) for multiple speech-type generation
 """)
 with gr.Blocks() as app_tts:
     gr.Markdown("# Batched TTS")
@@ -447,12 +408,7 @@ with gr.Blocks() as app_tts:
     remove_silence = gr.Checkbox(
         label="Remove Silences",
         info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.",
-        value=True,
-    )
-    split_words_input = gr.Textbox(
-        label="Custom Split Words",
-        info="Enter custom words to split on, separated by commas. Leave blank to use default list.",
-        lines=2,
+        value=False,
     )
     speed_slider = gr.Slider(
         label="Speed",
@@ -462,6 +418,14 @@ with gr.Blocks() as app_tts:
         step=0.1,
         info="Adjust the speed of the audio.",
     )
+    cross_fade_duration_slider = gr.Slider(
+        label="Cross-Fade Duration (s)",
+        minimum=0.0,
+        maximum=1.0,
+        value=0.15,
+        step=0.01,
+        info="Set the duration of the cross-fade between audio clips.",
+    )
     speed_slider.change(update_speed, inputs=speed_slider)
 
     audio_output = gr.Audio(label="Synthesized Audio")
@@ -475,7 +439,7 @@ with gr.Blocks() as app_tts:
             gen_text_input,
             model_choice,
             remove_silence,
-            split_words_input,
+            cross_fade_duration_slider,
         ],
         outputs=[audio_output, spectrogram_output],
     )
@@ -568,8 +532,8 @@ with gr.Blocks() as app_emotional:
         regular_audio = gr.Audio(label='Regular Reference Audio', type='filepath')
         regular_ref_text = gr.Textbox(label='Reference Text (Regular)', lines=2)
 
-    # Additional speech types (up to 9 more)
-    max_speech_types = 10
+    # Additional speech types (up to 99 more)
+    max_speech_types = 100
     speech_type_names = []
     speech_type_audios = []
     speech_type_ref_texts = []
@@ -681,8 +645,7 @@ with gr.Blocks() as app_emotional:
 
     # Output audio
    audio_output_emotional = gr.Audio(label="Synthesized Audio")
-
-    @spaces.GPU
+    @gpu_decorator
     def generate_emotional_speech(
         regular_audio,
         regular_ref_text,
@@ -724,7 +687,7 @@ with gr.Blocks() as app_emotional:
         ref_text = speech_types[current_emotion].get('ref_text', '')
 
         # Generate speech for this segment
-        audio, _ = infer(ref_audio, ref_text, text, model_choice, remove_silence, "")
+        audio, _ = infer(ref_audio, ref_text, text, model_choice, remove_silence, 0)
         sr, audio_data = audio
 
         generated_audio_segments.append(audio_data)
@@ -805,4 +768,27 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
 )
 gr.TabbedInterface([app_tts, app_podcast, app_emotional, app_credits], ["TTS", "Podcast", "Multi-Style", "Credits"])
 
-app.queue().launch()
+@click.command()
+@click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
+@click.option("--host", "-H", default=None, help="Host to run the app on")
+@click.option(
+    "--share",
+    "-s",
+    default=False,
+    is_flag=True,
+    help="Share the app via Gradio share link",
+)
+@click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
+def main(port, host, share, api):
+    global app
+    print(f"Starting app...")
+    app.queue(api_open=api).launch(
+        server_name=host, server_port=port, share=share, show_api=api
+    )
+
+
+if __name__ == "__main__":
+    if not USING_SPACES:
+        main()
+    else:
+        app.queue().launch()
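The cross-fading added to `infer_batch` above can be sanity-checked in isolation. The sketch below replays the same overlap-add logic on two synthetic waveforms; the 24 kHz rate and 0.15 s duration mirror the defaults in the diff, while the dummy signals are made up purely for illustration.

```python
import numpy as np

target_sample_rate = 24000          # matches the app's output rate
cross_fade_duration = 0.15          # matches the new slider's default
prev_wave = np.ones(48000, dtype=np.float32)   # 2 s dummy "previous" chunk
next_wave = np.zeros(48000, dtype=np.float32)  # 2 s dummy "next" chunk

# Overlap length, capped by both chunk lengths as in infer_batch.
n = min(int(cross_fade_duration * target_sample_rate), len(prev_wave), len(next_wave))
fade_out = np.linspace(1, 0, n)
fade_in = np.linspace(0, 1, n)
overlap = prev_wave[-n:] * fade_out + next_wave[:n] * fade_in

final_wave = np.concatenate([prev_wave[:-n], overlap, next_wave[n:]])
# The joined wave is shorter than plain concatenation by exactly the overlap length.
assert len(final_wave) == len(prev_wave) + len(next_wave) - n
```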
finetune-cli.py ADDED
@@ -0,0 +1,108 @@
+import argparse
+from model import CFM, UNetT, DiT, MMDiT, Trainer
+from model.utils import get_tokenizer
+from model.dataset import load_dataset
+from cached_path import cached_path
+import shutil, os
+# -------------------------- Dataset Settings --------------------------- #
+target_sample_rate = 24000
+n_mel_channels = 100
+hop_length = 256
+
+tokenizer = "pinyin"  # 'pinyin', 'char', or 'custom'
+tokenizer_path = None  # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+
+# -------------------------- Argument Parsing --------------------------- #
+def parse_args():
+    parser = argparse.ArgumentParser(description='Train CFM Model')
+
+    parser.add_argument('--exp_name', type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"], help='Experiment name')
+    parser.add_argument('--dataset_name', type=str, default="Emilia_ZH_EN", help='Name of the dataset to use')
+    parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate for training')
+    parser.add_argument('--batch_size_per_gpu', type=int, default=256, help='Batch size per GPU')
+    parser.add_argument('--batch_size_type', type=str, default="frame", choices=["frame", "sample"], help='Batch size type')
+    parser.add_argument('--max_samples', type=int, default=16, help='Max sequences per batch')
+    parser.add_argument('--grad_accumulation_steps', type=int, default=1, help='Gradient accumulation steps')
+    parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Max gradient norm for clipping')
+    parser.add_argument('--epochs', type=int, default=10, help='Number of training epochs')
+    parser.add_argument('--num_warmup_updates', type=int, default=5, help='Warmup steps')
+    parser.add_argument('--save_per_updates', type=int, default=10, help='Save checkpoint every X steps')
+    parser.add_argument('--last_per_steps', type=int, default=10, help='Save last checkpoint every X steps')
+    parser.add_argument('--finetune', type=bool, default=True, help='Use Finetune')
+
+    return parser.parse_args()
+
+# -------------------------- Training Settings -------------------------- #
+
+def main():
+    args = parse_args()
+
+
+    # Model parameters based on experiment name
+    if args.exp_name == "F5TTS_Base":
+        wandb_resume_id = None
+        model_cls = DiT
+        model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
+        if args.finetune:
+            ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
+    elif args.exp_name == "E2TTS_Base":
+        wandb_resume_id = None
+        model_cls = UNetT
+        model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
+        if args.finetune:
+            ckpt_path = str(cached_path(f"hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
+
+    if args.finetune:
+        path_ckpt = os.path.join("ckpts", args.dataset_name)
+        if os.path.isdir(path_ckpt) == False:
+            os.makedirs(path_ckpt, exist_ok=True)
+            shutil.copy2(ckpt_path, os.path.join(path_ckpt, os.path.basename(ckpt_path)))
+
+    checkpoint_path = os.path.join("ckpts", args.dataset_name)
+
+    # Use the dataset_name provided in the command line
+    tokenizer_path = args.dataset_name if tokenizer != "custom" else tokenizer_path
+    vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
+
+    mel_spec_kwargs = dict(
+        target_sample_rate=target_sample_rate,
+        n_mel_channels=n_mel_channels,
+        hop_length=hop_length,
+    )
+
+    e2tts = CFM(
+        transformer=model_cls(
+            **model_cfg,
+            text_num_embeds=vocab_size,
+            mel_dim=n_mel_channels
+        ),
+        mel_spec_kwargs=mel_spec_kwargs,
+        vocab_char_map=vocab_char_map,
+    )
+
+    trainer = Trainer(
+        e2tts,
+        args.epochs,
+        args.learning_rate,
+        num_warmup_updates=args.num_warmup_updates,
+        save_per_updates=args.save_per_updates,
+        checkpoint_path=checkpoint_path,
+        batch_size=args.batch_size_per_gpu,
+        batch_size_type=args.batch_size_type,
+        max_samples=args.max_samples,
+        grad_accumulation_steps=args.grad_accumulation_steps,
+        max_grad_norm=args.max_grad_norm,
+        wandb_project="CFM-TTS",
+        wandb_run_name=args.exp_name,
+        wandb_resume_id=wandb_resume_id,
+        last_per_steps=args.last_per_steps,
+    )
+
+    train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
+    trainer.train(train_dataset,
+                  resumable_with_seed=666  # seed for shuffling dataset
+                  )
+
+
+if __name__ == '__main__':
+    main()
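`finetune_gradio.py` (added below) launches this script through `accelerate`; the same invocation can be issued directly. A sketch assuming `accelerate` is installed and that a prepared dataset named `my_speak` exists under `data/my_speak_pinyin` (the dataset name is a placeholder; all other options fall back to the defaults defined in `parse_args`).

```python
import subprocess

# Minimal equivalent of the command string assembled in finetune_gradio.py's start_training().
cmd = [
    "accelerate", "launch", "finetune-cli.py",
    "--exp_name", "F5TTS_Base",
    "--dataset_name", "my_speak",  # placeholder dataset name
]
subprocess.run(cmd, check=True)
```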
finetune_gradio.py ADDED
@@ -0,0 +1,730 @@
1
+ import os,sys
2
+
3
+ from transformers import pipeline
4
+ import gradio as gr
5
+ import torch
6
+ import click
7
+ import torchaudio
8
+ from glob import glob
9
+ import librosa
10
+ import numpy as np
11
+ from scipy.io import wavfile
12
+ import shutil
13
+ import time
14
+
15
+ import json
16
+ from model.utils import convert_char_to_pinyin
17
+ import signal
18
+ import psutil
19
+ import platform
20
+ import subprocess
21
+ from datasets.arrow_writer import ArrowWriter
22
+
23
+ import json
24
+
25
+ training_process = None
26
+ system = platform.system()
27
+ python_executable = sys.executable or "python"
28
+
29
+ path_data="data"
30
+
31
+ device = (
32
+ "cuda"
33
+ if torch.cuda.is_available()
34
+ else "mps" if torch.backends.mps.is_available() else "cpu"
35
+ )
36
+
37
+ pipe = None
38
+
39
+ # Load metadata
40
+ def get_audio_duration(audio_path):
41
+ """Calculate the duration of an audio file."""
42
+ audio, sample_rate = torchaudio.load(audio_path)
43
+ num_channels = audio.shape[0]
44
+ return audio.shape[1] / (sample_rate * num_channels)
45
+
46
+ def clear_text(text):
47
+ """Clean and prepare text by lowering the case and stripping whitespace."""
48
+ return text.lower().strip()
49
+
50
+ def get_rms(y,frame_length=2048,hop_length=512,pad_mode="constant",): # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
51
+ padding = (int(frame_length // 2), int(frame_length // 2))
52
+ y = np.pad(y, padding, mode=pad_mode)
53
+
54
+ axis = -1
55
+ # put our new within-frame axis at the end for now
56
+ out_strides = y.strides + tuple([y.strides[axis]])
57
+ # Reduce the shape on the framing axis
58
+ x_shape_trimmed = list(y.shape)
59
+ x_shape_trimmed[axis] -= frame_length - 1
60
+ out_shape = tuple(x_shape_trimmed) + tuple([frame_length])
61
+ xw = np.lib.stride_tricks.as_strided(y, shape=out_shape, strides=out_strides)
62
+ if axis < 0:
63
+ target_axis = axis - 1
64
+ else:
65
+ target_axis = axis + 1
66
+ xw = np.moveaxis(xw, -1, target_axis)
67
+ # Downsample along the target axis
68
+ slices = [slice(None)] * xw.ndim
69
+ slices[axis] = slice(0, None, hop_length)
70
+ x = xw[tuple(slices)]
71
+
72
+ # Calculate power
73
+ power = np.mean(np.abs(x) ** 2, axis=-2, keepdims=True)
74
+
75
+ return np.sqrt(power)
76
+
77
+ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
78
+ def __init__(
79
+ self,
80
+ sr: int,
81
+ threshold: float = -40.0,
82
+ min_length: int = 2000,
83
+ min_interval: int = 300,
84
+ hop_size: int = 20,
85
+ max_sil_kept: int = 2000,
86
+ ):
87
+ if not min_length >= min_interval >= hop_size:
88
+ raise ValueError(
89
+ "The following condition must be satisfied: min_length >= min_interval >= hop_size"
90
+ )
91
+ if not max_sil_kept >= hop_size:
92
+ raise ValueError(
93
+ "The following condition must be satisfied: max_sil_kept >= hop_size"
94
+ )
95
+ min_interval = sr * min_interval / 1000
96
+ self.threshold = 10 ** (threshold / 20.0)
97
+ self.hop_size = round(sr * hop_size / 1000)
98
+ self.win_size = min(round(min_interval), 4 * self.hop_size)
99
+ self.min_length = round(sr * min_length / 1000 / self.hop_size)
100
+ self.min_interval = round(min_interval / self.hop_size)
101
+ self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)
102
+
103
+ def _apply_slice(self, waveform, begin, end):
104
+ if len(waveform.shape) > 1:
105
+ return waveform[
106
+ :, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size)
107
+ ]
108
+ else:
109
+ return waveform[
110
+ begin * self.hop_size : min(waveform.shape[0], end * self.hop_size)
111
+ ]
112
+
113
+ # @timeit
114
+ def slice(self, waveform):
115
+ if len(waveform.shape) > 1:
116
+ samples = waveform.mean(axis=0)
117
+ else:
118
+ samples = waveform
119
+ if samples.shape[0] <= self.min_length:
120
+ return [waveform]
121
+ rms_list = get_rms(
122
+ y=samples, frame_length=self.win_size, hop_length=self.hop_size
123
+ ).squeeze(0)
124
+ sil_tags = []
125
+ silence_start = None
126
+ clip_start = 0
127
+ for i, rms in enumerate(rms_list):
128
+ # Keep looping while frame is silent.
129
+ if rms < self.threshold:
130
+ # Record start of silent frames.
131
+ if silence_start is None:
132
+ silence_start = i
133
+ continue
134
+ # Keep looping while frame is not silent and silence start has not been recorded.
135
+ if silence_start is None:
136
+ continue
137
+ # Clear recorded silence start if interval is not enough or clip is too short
138
+ is_leading_silence = silence_start == 0 and i > self.max_sil_kept
139
+ need_slice_middle = (
140
+ i - silence_start >= self.min_interval
141
+ and i - clip_start >= self.min_length
142
+ )
143
+ if not is_leading_silence and not need_slice_middle:
144
+ silence_start = None
145
+ continue
146
+ # Need slicing. Record the range of silent frames to be removed.
147
+ if i - silence_start <= self.max_sil_kept:
148
+ pos = rms_list[silence_start : i + 1].argmin() + silence_start
149
+ if silence_start == 0:
150
+ sil_tags.append((0, pos))
151
+ else:
152
+ sil_tags.append((pos, pos))
153
+ clip_start = pos
154
+ elif i - silence_start <= self.max_sil_kept * 2:
155
+ pos = rms_list[
156
+ i - self.max_sil_kept : silence_start + self.max_sil_kept + 1
157
+ ].argmin()
158
+ pos += i - self.max_sil_kept
159
+ pos_l = (
160
+ rms_list[
161
+ silence_start : silence_start + self.max_sil_kept + 1
162
+ ].argmin()
163
+ + silence_start
164
+ )
165
+ pos_r = (
166
+ rms_list[i - self.max_sil_kept : i + 1].argmin()
167
+ + i
168
+ - self.max_sil_kept
169
+ )
170
+ if silence_start == 0:
171
+ sil_tags.append((0, pos_r))
172
+ clip_start = pos_r
173
+ else:
174
+ sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
175
+ clip_start = max(pos_r, pos)
176
+ else:
177
+ pos_l = (
178
+ rms_list[
179
+ silence_start : silence_start + self.max_sil_kept + 1
180
+ ].argmin()
181
+ + silence_start
182
+ )
183
+ pos_r = (
184
+ rms_list[i - self.max_sil_kept : i + 1].argmin()
185
+ + i
186
+ - self.max_sil_kept
187
+ )
188
+ if silence_start == 0:
189
+ sil_tags.append((0, pos_r))
190
+ else:
191
+ sil_tags.append((pos_l, pos_r))
192
+ clip_start = pos_r
193
+ silence_start = None
194
+ # Deal with trailing silence.
195
+ total_frames = rms_list.shape[0]
196
+ if (
197
+ silence_start is not None
198
+ and total_frames - silence_start >= self.min_interval
199
+ ):
200
+ silence_end = min(total_frames, silence_start + self.max_sil_kept)
201
+ pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
202
+ sil_tags.append((pos, total_frames + 1))
203
+ # Apply and return slices.
204
+ ####音频+起始时间+终止时间
205
+ if len(sil_tags) == 0:
206
+ return [[waveform,0,int(total_frames*self.hop_size)]]
207
+ else:
208
+ chunks = []
209
+ if sil_tags[0][0] > 0:
210
+ chunks.append([self._apply_slice(waveform, 0, sil_tags[0][0]),0,int(sil_tags[0][0]*self.hop_size)])
211
+ for i in range(len(sil_tags) - 1):
212
+ chunks.append(
213
+ [self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]),int(sil_tags[i][1]*self.hop_size),int(sil_tags[i + 1][0]*self.hop_size)]
214
+ )
215
+ if sil_tags[-1][1] < total_frames:
216
+ chunks.append(
217
+ [self._apply_slice(waveform, sil_tags[-1][1], total_frames),int(sil_tags[-1][1]*self.hop_size),int(total_frames*self.hop_size)]
218
+ )
219
+ return chunks
220
+
221
+ #terminal
222
+ def terminate_process_tree(pid, including_parent=True):
223
+ try:
224
+ parent = psutil.Process(pid)
225
+ except psutil.NoSuchProcess:
226
+ # Process already terminated
227
+ return
228
+
229
+ children = parent.children(recursive=True)
230
+ for child in children:
231
+ try:
232
+ os.kill(child.pid, signal.SIGTERM) # or signal.SIGKILL
233
+ except OSError:
234
+ pass
235
+ if including_parent:
236
+ try:
237
+ os.kill(parent.pid, signal.SIGTERM) # or signal.SIGKILL
238
+ except OSError:
239
+ pass
240
+
241
+ def terminate_process(pid):
242
+ if system == "Windows":
243
+ cmd = f"taskkill /t /f /pid {pid}"
244
+ os.system(cmd)
245
+ else:
246
+ terminate_process_tree(pid)
247
+
248
+ def start_training(dataset_name="",
249
+ exp_name="F5TTS_Base",
250
+ learning_rate=1e-4,
251
+ batch_size_per_gpu=400,
252
+ batch_size_type="frame",
253
+ max_samples=64,
254
+ grad_accumulation_steps=1,
255
+ max_grad_norm=1.0,
256
+ epochs=11,
257
+ num_warmup_updates=200,
258
+ save_per_updates=400,
259
+ last_per_steps=800,
260
+ finetune=True,
261
+ ):
262
+
263
+
264
+ global training_process
265
+
266
+ path_project = os.path.join(path_data, dataset_name + "_pinyin")
267
+
268
+ if os.path.isdir(path_project)==False:
269
+ yield f"There is not project with name {dataset_name}",gr.update(interactive=True),gr.update(interactive=False)
270
+ return
271
+
272
+ file_raw = os.path.join(path_project,"raw.arrow")
273
+ if os.path.isfile(file_raw)==False:
274
+ yield f"There is no file {file_raw}",gr.update(interactive=True),gr.update(interactive=False)
275
+ return
276
+
277
+ # Check if a training process is already running
278
+ if training_process is not None:
279
+ return "Train run already!",gr.update(interactive=False),gr.update(interactive=True)
280
+
281
+ yield "start train",gr.update(interactive=False),gr.update(interactive=False)
282
+
283
+ # Command to run the training script with the specified arguments
284
+ cmd = f"accelerate launch finetune-cli.py --exp_name {exp_name} " \
285
+ f"--learning_rate {learning_rate} " \
286
+ f"--batch_size_per_gpu {batch_size_per_gpu} " \
287
+ f"--batch_size_type {batch_size_type} " \
288
+ f"--max_samples {max_samples} " \
289
+ f"--grad_accumulation_steps {grad_accumulation_steps} " \
290
+ f"--max_grad_norm {max_grad_norm} " \
291
+ f"--epochs {epochs} " \
292
+ f"--num_warmup_updates {num_warmup_updates} " \
293
+ f"--save_per_updates {save_per_updates} " \
294
+ f"--last_per_steps {last_per_steps} " \
295
+ f"--dataset_name {dataset_name}"
296
+ if finetune:cmd += f" --finetune {finetune}"
297
+ print(cmd)
298
+ try:
299
+ # Start the training process
300
+ training_process = subprocess.Popen(cmd, shell=True)
301
+
302
+ time.sleep(5)
303
+ yield "check terminal for wandb",gr.update(interactive=False),gr.update(interactive=True)
304
+
305
+ # Wait for the training process to finish
306
+ training_process.wait()
307
+ time.sleep(1)
308
+
309
+ if training_process is None:
310
+ text_info = 'train stop'
311
+ else:
312
+ text_info = "train complete !"
313
+
314
+ except Exception as e: # Catch all exceptions
315
+ # Ensure that we reset the training process variable in case of an error
316
+ text_info=f"An error occurred: {str(e)}"
317
+
318
+ training_process=None
319
+
320
+ yield text_info,gr.update(interactive=True),gr.update(interactive=False)
321
+
322
+ def stop_training():
323
+ global training_process
324
+ if training_process is None:return f"Train not run !",gr.update(interactive=True),gr.update(interactive=False)
325
+ terminate_process_tree(training_process.pid)
326
+ training_process = None
327
+ return 'train stop',gr.update(interactive=True),gr.update(interactive=False)
328
+
329
+ def create_data_project(name):
330
+ name+="_pinyin"
331
+ os.makedirs(os.path.join(path_data,name),exist_ok=True)
332
+ os.makedirs(os.path.join(path_data,name,"dataset"),exist_ok=True)
333
+
334
+ def transcribe(file_audio,language="english"):
335
+ global pipe
336
+
337
+ if pipe is None:
338
+ pipe = pipeline("automatic-speech-recognition",model="openai/whisper-large-v3-turbo", torch_dtype=torch.float16,device=device)
339
+
340
+ text_transcribe = pipe(
341
+ file_audio,
342
+ chunk_length_s=30,
343
+ batch_size=128,
344
+ generate_kwargs={"task": "transcribe","language": language},
345
+ return_timestamps=False,
346
+ )["text"].strip()
347
+ return text_transcribe
348
+
349
+ def transcribe_all(name_project,audio_files,language,user=False,progress=gr.Progress()):
350
+ name_project+="_pinyin"
351
+ path_project= os.path.join(path_data,name_project)
352
+ path_dataset = os.path.join(path_project,"dataset")
353
+ path_project_wavs = os.path.join(path_project,"wavs")
354
+ file_metadata = os.path.join(path_project,"metadata.csv")
355
+
356
+ if audio_files is None:return "You need to load an audio file."
357
+
358
+ if os.path.isdir(path_project_wavs):
359
+ shutil.rmtree(path_project_wavs)
360
+
361
+ if os.path.isfile(file_metadata):
362
+ os.remove(file_metadata)
363
+
364
+ os.makedirs(path_project_wavs,exist_ok=True)
365
+
366
+ if user:
367
+ file_audios = [file for format in ('*.wav', '*.ogg', '*.opus', '*.mp3', '*.flac') for file in glob(os.path.join(path_dataset, format))]
368
+ if file_audios==[]:return "No audio file was found in the dataset."
369
+ else:
370
+ file_audios = audio_files
371
+
372
+
373
+ alpha = 0.5
374
+ _max = 1.0
375
+ slicer = Slicer(24000)
376
+
377
+ num = 0
378
+ error_num = 0
379
+ data=""
380
+ for file_audio in progress.tqdm(file_audios, desc="transcribe files",total=len((file_audios))):
381
+
382
+ audio, _ = librosa.load(file_audio, sr=24000, mono=True)
383
+
384
+ list_slicer=slicer.slice(audio)
385
+ for chunk, start, end in progress.tqdm(list_slicer,total=len(list_slicer), desc="slicer files"):
386
+
387
+ name_segment = os.path.join(f"segment_{num}")
388
+ file_segment = os.path.join(path_project_wavs, f"{name_segment}.wav")
389
+
390
+ tmp_max = np.abs(chunk).max()
391
+ if(tmp_max>1):chunk/=tmp_max
392
+ chunk = (chunk / tmp_max * (_max * alpha)) + (1 - alpha) * chunk
393
+ wavfile.write(file_segment,24000, (chunk * 32767).astype(np.int16))
394
+
395
+ try:
396
+ text=transcribe(file_segment,language)
397
+ text = text.lower().strip().replace('"',"")
398
+
399
+ data+= f"{name_segment}|{text}\n"
400
+
401
+ num+=1
402
+ except:
403
+ error_num +=1
404
+
405
+ with open(file_metadata,"w",encoding="utf-8") as f:
406
+ f.write(data)
407
+
408
+ if error_num!=[]:
409
+ error_text=f"\nerror files : {error_num}"
410
+ else:
411
+ error_text=""
412
+
413
+ return f"transcribe complete samples : {num}\npath : {path_project_wavs}{error_text}"
414
+
415
+ def format_seconds_to_hms(seconds):
416
+ hours = int(seconds / 3600)
417
+ minutes = int((seconds % 3600) / 60)
418
+ seconds = seconds % 60
419
+ return "{:02d}:{:02d}:{:02d}".format(hours, minutes, int(seconds))
420
+
421
+ def create_metadata(name_project,progress=gr.Progress()):
422
+ name_project+="_pinyin"
423
+ path_project= os.path.join(path_data,name_project)
424
+ path_project_wavs = os.path.join(path_project,"wavs")
425
+ file_metadata = os.path.join(path_project,"metadata.csv")
426
+ file_raw = os.path.join(path_project,"raw.arrow")
427
+ file_duration = os.path.join(path_project,"duration.json")
428
+ file_vocab = os.path.join(path_project,"vocab.txt")
429
+
430
+ if os.path.isfile(file_metadata)==False: return "The file was not found in " + file_metadata
431
+
432
+ with open(file_metadata,"r",encoding="utf-8") as f:
433
+ data=f.read()
434
+
435
+ audio_path_list=[]
436
+ text_list=[]
437
+ duration_list=[]
438
+
439
+ count=data.split("\n")
440
+ lenght=0
441
+ result=[]
442
+ error_files=[]
443
+ for line in progress.tqdm(data.split("\n"),total=count):
444
+ sp_line=line.split("|")
445
+ if len(sp_line)!=2:continue
446
+ name_audio,text = sp_line[:2]
447
+
448
+ file_audio = os.path.join(path_project_wavs, name_audio + ".wav")
449
+
450
+ if os.path.isfile(file_audio)==False:
451
+ error_files.append(file_audio)
452
+ continue
453
+
454
+ duraction = get_audio_duration(file_audio)
455
+ if duraction<2 and duraction>15:continue
456
+ if len(text)<4:continue
457
+
458
+ text = clear_text(text)
459
+ text = convert_char_to_pinyin([text], polyphone = True)[0]
460
+
461
+ audio_path_list.append(file_audio)
462
+ duration_list.append(duraction)
463
+ text_list.append(text)
464
+
465
+ result.append({"audio_path": file_audio, "text": text, "duration": duraction})
466
+
467
+ lenght+=duraction
468
+
469
+ if duration_list==[]:
470
+ error_files_text="\n".join(error_files)
471
+ return f"Error: No audio files found in the specified path : \n{error_files_text}"
472
+
473
+ min_second = round(min(duration_list),2)
474
+ max_second = round(max(duration_list),2)
475
+
476
+ with ArrowWriter(path=file_raw, writer_batch_size=1) as writer:
477
+ for line in progress.tqdm(result,total=len(result), desc=f"prepare data"):
478
+ writer.write(line)
479
+
480
+ with open(file_duration, 'w', encoding='utf-8') as f:
481
+ json.dump({"duration": duration_list}, f, ensure_ascii=False)
482
+
483
+ file_vocab_finetune = "data/Emilia_ZH_EN_pinyin/vocab.txt"
484
+ if os.path.isfile(file_vocab_finetune==False):return "Error: Vocabulary file 'Emilia_ZH_EN_pinyin' not found!"
485
+ shutil.copy2(file_vocab_finetune, file_vocab)
486
+
487
+ if error_files!=[]:
488
+ error_text="error files\n" + "\n".join(error_files)
489
+ else:
490
+ error_text=""
491
+
492
+ return f"prepare complete \nsamples : {len(text_list)}\ntime data : {format_seconds_to_hms(lenght)}\nmin sec : {min_second}\nmax sec : {max_second}\nfile_arrow : {file_raw}\n{error_text}"
493
+
494
+ def check_user(value):
495
+ return gr.update(visible=not value),gr.update(visible=value)
496
+
497
+ def calculate_train(name_project,batch_size_type,max_samples,learning_rate,num_warmup_updates,save_per_updates,last_per_steps,finetune):
498
+ name_project+="_pinyin"
499
+ path_project= os.path.join(path_data,name_project)
500
+ file_duraction = os.path.join(path_project,"duration.json")
501
+
502
+ with open(file_duraction, 'r') as file:
503
+ data = json.load(file)
504
+
505
+ duration_list = data['duration']
506
+
507
+ samples = len(duration_list)
508
+
509
+ if torch.cuda.is_available():
510
+ gpu_properties = torch.cuda.get_device_properties(0)
511
+ total_memory = gpu_properties.total_memory / (1024 ** 3)
512
+ elif torch.backends.mps.is_available():
513
+ total_memory = psutil.virtual_memory().available / (1024 ** 3)
514
+
515
+ if batch_size_type=="frame":
516
+ batch = int(total_memory * 0.5)
517
+ batch = (lambda num: num + 1 if num % 2 != 0 else num)(batch)
518
+ batch_size_per_gpu = int(38400 / batch )
519
+ else:
520
+ batch_size_per_gpu = int(total_memory / 8)
521
+ batch_size_per_gpu = (lambda num: num + 1 if num % 2 != 0 else num)(batch_size_per_gpu)
522
+ batch = batch_size_per_gpu
523
+
524
+ if batch_size_per_gpu<=0:batch_size_per_gpu=1
525
+
526
+ if samples<64:
527
+ max_samples = int(samples * 0.25)
528
+
529
+ num_warmup_updates = int(samples * 0.10)
530
+ save_per_updates = int(samples * 0.25)
531
+ last_per_steps =int(save_per_updates * 5)
532
+
533
+ max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples)
534
+ num_warmup_updates = (lambda num: num + 1 if num % 2 != 0 else num)(num_warmup_updates)
535
+ save_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(save_per_updates)
536
+ last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps)
537
+
538
+ if finetune:learning_rate=1e-4
539
+ else:learning_rate=7.5e-5
540
+
541
+ return batch_size_per_gpu,max_samples,num_warmup_updates,save_per_updates,last_per_steps,samples,learning_rate
542
+
543
+ def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str) -> None:
544
+ try:
545
+ checkpoint = torch.load(checkpoint_path)
546
+ print("Original Checkpoint Keys:", checkpoint.keys())
547
+
548
+ ema_model_state_dict = checkpoint.get('ema_model_state_dict', None)
549
+
550
+ if ema_model_state_dict is not None:
551
+ new_checkpoint = {'ema_model_state_dict': ema_model_state_dict}
552
+ torch.save(new_checkpoint, new_checkpoint_path)
553
+ return f"New checkpoint saved at: {new_checkpoint_path}"
554
+ else:
555
+ return "No 'ema_model_state_dict' found in the checkpoint."
556
+
557
+ except Exception as e:
558
+ return f"An error occurred: {e}"
559
+
560
+ def vocab_check(project_name):
561
+ name_project = project_name + "_pinyin"
562
+ path_project = os.path.join(path_data, name_project)
563
+
564
+ file_metadata = os.path.join(path_project, "metadata.csv")
565
+
566
+ file_vocab="data/Emilia_ZH_EN_pinyin/vocab.txt"
567
+ if os.path.isfile(file_vocab)==False:
568
+ return f"the file {file_vocab} not found !"
569
+
570
+ with open(file_vocab,"r",encoding="utf-8") as f:
571
+ data=f.read()
572
+
573
+ vocab = data.split("\n")
574
+
575
+ if os.path.isfile(file_metadata)==False:
576
+ return f"the file {file_metadata} not found !"
577
+
578
+ with open(file_metadata,"r",encoding="utf-8") as f:
579
+ data=f.read()
580
+
581
+ miss_symbols=[]
582
+ miss_symbols_keep={}
583
+ for item in data.split("\n"):
584
+ sp=item.split("|")
585
+ if len(sp)!=2:continue
586
+ text=sp[1].lower().strip()
587
+
588
+ for t in text:
589
+ if (t in vocab)==False and (t in miss_symbols_keep)==False:
590
+ miss_symbols.append(t)
591
+ miss_symbols_keep[t]=t
592
+
593
+
594
+ if miss_symbols==[]:info ="You can train using your language !"
595
+ else:info = f"The following symbols are missing in your language : {len(miss_symbols)}\n\n" + "\n".join(miss_symbols)
596
+
597
+ return info
598
+
599
+
600
+
601
+ with gr.Blocks() as app:
602
+
603
+ with gr.Row():
604
+ project_name=gr.Textbox(label="project name",value="my_speak")
605
+ bt_create=gr.Button("create new project")
606
+
607
+ bt_create.click(fn=create_data_project,inputs=[project_name])
608
+
609
+ with gr.Tabs():
610
+
611
+
612
+ with gr.TabItem("transcribe Data"):
613
+
614
+
615
+ ch_manual = gr.Checkbox(label="user",value=False)
616
+
617
+ mark_info_transcribe=gr.Markdown(
618
+ """```plaintext
619
+ Place your 'wavs' folder and 'metadata.csv' file in the {your_project_name}' directory.
620
+
621
+ my_speak/
622
+
623
+ └── dataset/
624
+ ├── audio1.wav
625
+ └── audio2.wav
626
+ ...
627
+ ```""",visible=False)
628
+
629
+ audio_speaker = gr.File(label="voice",type="filepath",file_count="multiple")
630
+ txt_lang = gr.Text(label="Language",value="english")
631
+ bt_transcribe=bt_create=gr.Button("transcribe")
632
+ txt_info_transcribe=gr.Text(label="info",value="")
633
+ bt_transcribe.click(fn=transcribe_all,inputs=[project_name,audio_speaker,txt_lang,ch_manual],outputs=[txt_info_transcribe])
634
+ ch_manual.change(fn=check_user,inputs=[ch_manual],outputs=[audio_speaker,mark_info_transcribe])
635
+
636
+ with gr.TabItem("prepare Data"):
637
+ gr.Markdown(
638
+ """```plaintext
639
+ Place your 'wavs' folder and your 'metadata.csv' file in the '{your_project_name}' directory, as shown below:
640
+ my_speak/
641
+
642
+ ├── wavs/
643
+ │ ├── audio1.wav
644
+ │ └── audio2.wav
645
+ | ...
646
+
647
+ └── metadata.csv
648
+
649
+ metadata.csv format:
650
+
651
+ audio1|text1
652
+ audio2|text2
653
+ ...
654
+
655
+ ```""")
656
+
657
+ bt_prepare=bt_create=gr.Button("prepare")
658
+ txt_info_prepare=gr.Text(label="info",value="")
659
+ bt_prepare.click(fn=create_metadata,inputs=[project_name],outputs=[txt_info_prepare])
660
+
661
+ with gr.TabItem("train Data"):
662
+
663
+ with gr.Row():
664
+ bt_calculate=bt_create=gr.Button("Auto Settings")
665
+ ch_finetune=bt_create=gr.Checkbox(label="finetune",value=True)
666
+ lb_samples = gr.Label(label="samples")
667
+ batch_size_type = gr.Radio(label="Batch Size Type", choices=["frame", "sample"], value="frame")
668
+
669
+ with gr.Row():
670
+ exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
671
+ learning_rate = gr.Number(label="Learning Rate", value=1e-4, step=1e-4)
672
+
673
+ with gr.Row():
674
+ batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=1000)
675
+ max_samples = gr.Number(label="Max Samples", value=16)
676
+
677
+ with gr.Row():
678
+ grad_accumulation_steps = gr.Number(label="Gradient Accumulation Steps", value=1)
679
+ max_grad_norm = gr.Number(label="Max Gradient Norm", value=1.0)
680
+
681
+ with gr.Row():
682
+ epochs = gr.Number(label="Epochs", value=10)
683
+ num_warmup_updates = gr.Number(label="Warmup Updates", value=5)
684
+
685
+ with gr.Row():
686
+ save_per_updates = gr.Number(label="Save per Updates", value=10)
687
+ last_per_steps = gr.Number(label="Last per Steps", value=50)
688
+
689
+ with gr.Row():
690
+ start_button = gr.Button("Start Training")
691
+ stop_button = gr.Button("Stop Training",interactive=False)
692
+
693
+ txt_info_train=gr.Text(label="info",value="")
694
+ start_button.click(fn=start_training,inputs=[project_name,exp_name,learning_rate,batch_size_per_gpu,batch_size_type,max_samples,grad_accumulation_steps,max_grad_norm,epochs,num_warmup_updates,save_per_updates,last_per_steps,ch_finetune],outputs=[txt_info_train,start_button,stop_button])
695
+ stop_button.click(fn=stop_training,outputs=[txt_info_train,start_button,stop_button])
696
+ bt_calculate.click(fn=calculate_train,inputs=[project_name,batch_size_type,max_samples,learning_rate,num_warmup_updates,save_per_updates,last_per_steps,ch_finetune],outputs=[batch_size_per_gpu,max_samples,num_warmup_updates,save_per_updates,last_per_steps,lb_samples,learning_rate])
697
+
698
+ with gr.TabItem("reduse checkpoint"):
699
+ txt_path_checkpoint = gr.Text(label="path checkpoint :")
700
+ txt_path_checkpoint_small = gr.Text(label="path output :")
701
+ txt_info_reduse = gr.Text(label="info",value="")
702
+ reduse_button = gr.Button("reduse")
703
+ reduse_button.click(fn=extract_and_save_ema_model,inputs=[txt_path_checkpoint,txt_path_checkpoint_small],outputs=[txt_info_reduse])
704
+
705
+ with gr.TabItem("vocab check experiment"):
706
+ check_button = gr.Button("check vocab")
707
+ txt_info_check=gr.Text(label="info",value="")
708
+ check_button.click(fn=vocab_check,inputs=[project_name],outputs=[txt_info_check])
709
+
710
+
711
+ @click.command()
712
+ @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
713
+ @click.option("--host", "-H", default=None, help="Host to run the app on")
714
+ @click.option(
715
+ "--share",
716
+ "-s",
717
+ default=False,
718
+ is_flag=True,
719
+ help="Share the app via Gradio share link",
720
+ )
721
+ @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
722
+ def main(port, host, share, api):
723
+ global app
724
+ print(f"Starting app...")
725
+ app.queue(api_open=api).launch(
726
+ server_name=host, server_port=port, share=share, show_api=api
727
+ )
728
+
729
+ if __name__ == "__main__":
730
+ main()
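
The "reduce checkpoint" tab above simply wraps `extract_and_save_ema_model`. For reference, a minimal standalone sketch of the same reduction (the paths are illustrative, not from this repo):

```python
import torch

# Keep only the EMA weights from a full training checkpoint,
# which is what the "reduce checkpoint" tab does under the hood.
checkpoint = torch.load("ckpts/my_speak/model_10000.pt", map_location="cpu")  # example path
torch.save(
    {"ema_model_state_dict": checkpoint["ema_model_state_dict"]},
    "ckpts/my_speak/model_10000_ema.pt",  # smaller, inference-ready copy
)
```
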
inference-cli.py CHANGED
@@ -1,25 +1,24 @@
 
 
1
  import re
 
 
 
 
 
 
2
  import torch
3
  import torchaudio
4
- import numpy as np
5
- import tempfile
6
  from einops import rearrange
7
- from vocos import Vocos
8
  from pydub import AudioSegment, silence
9
- from model import CFM, UNetT, DiT, MMDiT
10
- from cached_path import cached_path
11
- from model.utils import (
12
- load_checkpoint,
13
- get_tokenizer,
14
- convert_char_to_pinyin,
15
- save_spectrogram,
16
- )
17
  from transformers import pipeline
18
- import soundfile as sf
19
- import tomli
20
- import argparse
21
- import tqdm
22
- from pathlib import Path
23
 
24
  parser = argparse.ArgumentParser(
25
  prog="python3 inference-cli.py",
@@ -56,6 +55,12 @@ parser.add_argument(
56
  type=str,
57
  help="Text to generate.",
58
  )
 
 
 
 
 
 
59
  parser.add_argument(
60
  "-o",
61
  "--output_dir",
@@ -66,6 +71,11 @@ parser.add_argument(
66
  "--remove_silence",
67
  help="Remove silence.",
68
  )
 
 
 
 
 
69
  args = parser.parse_args()
70
 
71
  config = tomli.load(open(args.config, "rb"))
@@ -73,29 +83,31 @@ config = tomli.load(open(args.config, "rb"))
73
  ref_audio = args.ref_audio if args.ref_audio else config["ref_audio"]
74
  ref_text = args.ref_text if args.ref_text != "666" else config["ref_text"]
75
  gen_text = args.gen_text if args.gen_text else config["gen_text"]
 
 
 
76
  output_dir = args.output_dir if args.output_dir else config["output_dir"]
77
  model = args.model if args.model else config["model"]
78
  remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
79
  wave_path = Path(output_dir)/"out.wav"
80
  spectrogram_path = Path(output_dir)/"out.png"
81
-
82
- SPLIT_WORDS = [
83
- "but", "however", "nevertheless", "yet", "still",
84
- "therefore", "thus", "hence", "consequently",
85
- "moreover", "furthermore", "additionally",
86
- "meanwhile", "alternatively", "otherwise",
87
- "namely", "specifically", "for example", "such as",
88
- "in fact", "indeed", "notably",
89
- "in contrast", "on the other hand", "conversely",
90
- "in conclusion", "to summarize", "finally"
91
- ]
92
 
93
  device = (
94
  "cuda"
95
  if torch.cuda.is_available()
96
  else "mps" if torch.backends.mps.is_available() else "cpu"
97
  )
98
- vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
 
 
 
 
 
 
 
 
 
99
 
100
  print(f"Using {device} device")
101
 
@@ -114,8 +126,9 @@ speed = 1.0
114
  fix_duration = None
115
 
116
  def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
117
- ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
118
- # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
 
119
  vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
120
  model = CFM(
121
  transformer=model_cls(
@@ -143,103 +156,36 @@ F5TTS_model_cfg = dict(
143
  )
144
  E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
145
 
146
- def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
147
- if len(text.encode('utf-8')) <= max_chars:
148
- return [text]
149
- if text[-1] not in ['。', '.', '!', '!', '?', '?']:
150
- text += '.'
151
-
152
- sentences = re.split('([。.!?!?])', text)
153
- sentences = [''.join(i) for i in zip(sentences[0::2], sentences[1::2])]
154
-
155
- batches = []
156
- current_batch = ""
157
-
158
- def split_by_words(text):
159
- words = text.split()
160
- current_word_part = ""
161
- word_batches = []
162
- for word in words:
163
- if len(current_word_part.encode('utf-8')) + len(word.encode('utf-8')) + 1 <= max_chars:
164
- current_word_part += word + ' '
165
- else:
166
- if current_word_part:
167
- # Try to find a suitable split word
168
- for split_word in split_words:
169
- split_index = current_word_part.rfind(' ' + split_word + ' ')
170
- if split_index != -1:
171
- word_batches.append(current_word_part[:split_index].strip())
172
- current_word_part = current_word_part[split_index:].strip() + ' '
173
- break
174
- else:
175
- # If no suitable split word found, just append the current part
176
- word_batches.append(current_word_part.strip())
177
- current_word_part = ""
178
- current_word_part += word + ' '
179
- if current_word_part:
180
- word_batches.append(current_word_part.strip())
181
- return word_batches
182
 
183
  for sentence in sentences:
184
- if len(current_batch.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
185
- current_batch += sentence
186
  else:
187
- # If adding this sentence would exceed the limit
188
- if current_batch:
189
- batches.append(current_batch)
190
- current_batch = ""
191
-
192
- # If the sentence itself is longer than max_chars, split it
193
- if len(sentence.encode('utf-8')) > max_chars:
194
- # First, try to split by colon
195
- colon_parts = sentence.split(':')
196
- if len(colon_parts) > 1:
197
- for part in colon_parts:
198
- if len(part.encode('utf-8')) <= max_chars:
199
- batches.append(part)
200
- else:
201
- # If colon part is still too long, split by comma
202
- comma_parts = re.split('[,,]', part)
203
- if len(comma_parts) > 1:
204
- current_comma_part = ""
205
- for comma_part in comma_parts:
206
- if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
207
- current_comma_part += comma_part + ','
208
- else:
209
- if current_comma_part:
210
- batches.append(current_comma_part.rstrip(','))
211
- current_comma_part = comma_part + ','
212
- if current_comma_part:
213
- batches.append(current_comma_part.rstrip(','))
214
- else:
215
- # If no comma, split by words
216
- batches.extend(split_by_words(part))
217
- else:
218
- # If no colon, split by comma
219
- comma_parts = re.split('[,,]', sentence)
220
- if len(comma_parts) > 1:
221
- current_comma_part = ""
222
- for comma_part in comma_parts:
223
- if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
224
- current_comma_part += comma_part + ','
225
- else:
226
- if current_comma_part:
227
- batches.append(current_comma_part.rstrip(','))
228
- current_comma_part = comma_part + ','
229
- if current_comma_part:
230
- batches.append(current_comma_part.rstrip(','))
231
- else:
232
- # If no comma, split by words
233
- batches.extend(split_by_words(sentence))
234
- else:
235
- current_batch = sentence
236
-
237
- if current_batch:
238
- batches.append(current_batch)
239
-
240
- return batches
241
 
242
- def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence):
 
 
 
 
 
 
243
  if model == "F5-TTS":
244
  ema_model = load_model(model, "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
245
  elif model == "E2-TTS":
@@ -297,41 +243,56 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence):
297
  generated_waves.append(generated_wave)
298
  spectrograms.append(generated_mel_spec[0].cpu().numpy())
299
 
300
- # Combine all generated waves
301
- final_wave = np.concatenate(generated_waves)
302
-
303
- with open(wave_path, "wb") as f:
304
- sf.write(f.name, final_wave, target_sample_rate)
305
- # Remove silence
306
- if remove_silence:
307
- aseg = AudioSegment.from_file(f.name)
308
- non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
309
- non_silent_wave = AudioSegment.silent(duration=0)
310
- for non_silent_seg in non_silent_segs:
311
- non_silent_wave += non_silent_seg
312
- aseg = non_silent_wave
313
- aseg.export(f.name, format="wav")
314
- print(f.name)
315
 
316
- # Create a combined spectrogram
317
- combined_spectrogram = np.concatenate(spectrograms, axis=1)
318
- save_spectrogram(combined_spectrogram, spectrogram_path)
319
- print(spectrogram_path)
320
 
 
 
 
 
321
 
322
- def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, custom_split_words):
323
- if not custom_split_words.strip():
324
- custom_words = [word.strip() for word in custom_split_words.split(',')]
325
- global SPLIT_WORDS
326
- SPLIT_WORDS = custom_words
327
 
328
- print(gen_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
329
 
 
 
 
330
  print("Converting audio...")
331
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
332
  aseg = AudioSegment.from_file(ref_audio_orig)
333
 
334
- non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
335
  non_silent_wave = AudioSegment.silent(duration=0)
336
  for non_silent_seg in non_silent_segs:
337
  non_silent_wave += non_silent_seg
@@ -362,17 +323,70 @@ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, custom_spli
362
  print("Finished transcription")
363
  else:
364
  print("Using custom reference text...")
 
 
 
 
 
 
 
 
 
 
365
 
366
  # Split the input text into batches
367
  audio, sr = torchaudio.load(ref_audio)
368
- max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (30 - audio.shape[-1] / sr))
369
- gen_text_batches = split_text_into_batches(gen_text, max_chars=max_chars)
370
  print('ref_text', ref_text)
371
  for i, gen_text in enumerate(gen_text_batches):
372
  print(f'gen_text {i}', gen_text)
373
 
374
  print(f"Generating audio using {model} in {len(gen_text_batches)} batches, loading models...")
375
- return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence)
376
 
377
 
378
- infer(ref_audio, ref_text, gen_text, model, remove_silence, ",".join(SPLIT_WORDS))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import codecs
3
  import re
4
+ import tempfile
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import soundfile as sf
9
+ import tomli
10
  import torch
11
  import torchaudio
12
+ import tqdm
13
+ from cached_path import cached_path
14
  from einops import rearrange
 
15
  from pydub import AudioSegment, silence
 
 
 
 
 
 
 
 
16
  from transformers import pipeline
17
+ from vocos import Vocos
18
+
19
+ from model import CFM, DiT, MMDiT, UNetT
20
+ from model.utils import (convert_char_to_pinyin, get_tokenizer,
21
+ load_checkpoint, save_spectrogram)
22
 
23
  parser = argparse.ArgumentParser(
24
  prog="python3 inference-cli.py",
 
55
  type=str,
56
  help="Text to generate.",
57
  )
58
+ parser.add_argument(
59
+ "-f",
60
+ "--gen_file",
61
+ type=str,
62
+ help="File with text to generate. Ignores --text",
63
+ )
64
  parser.add_argument(
65
  "-o",
66
  "--output_dir",
 
71
  "--remove_silence",
72
  help="Remove silence.",
73
  )
74
+ parser.add_argument(
75
+ "--load_vocoder_from_local",
76
+ action="store_true",
77
+ help="load vocoder from local. Default: ../checkpoints/charactr/vocos-mel-24khz",
78
+ )
79
  args = parser.parse_args()
80
 
81
  config = tomli.load(open(args.config, "rb"))
 
83
  ref_audio = args.ref_audio if args.ref_audio else config["ref_audio"]
84
  ref_text = args.ref_text if args.ref_text != "666" else config["ref_text"]
85
  gen_text = args.gen_text if args.gen_text else config["gen_text"]
86
+ gen_file = args.gen_file if args.gen_file else config["gen_file"]
87
+ if gen_file:
88
+ gen_text = codecs.open(gen_file, "r", "utf-8").read()
89
  output_dir = args.output_dir if args.output_dir else config["output_dir"]
90
  model = args.model if args.model else config["model"]
91
  remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
92
  wave_path = Path(output_dir)/"out.wav"
93
  spectrogram_path = Path(output_dir)/"out.png"
94
+ vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
 
 
 
 
 
 
 
 
 
 
95
 
96
  device = (
97
  "cuda"
98
  if torch.cuda.is_available()
99
  else "mps" if torch.backends.mps.is_available() else "cpu"
100
  )
101
+
102
+ if args.load_vocoder_from_local:
103
+ print(f"Load vocos from local path {vocos_local_path}")
104
+ vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
105
+ state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
106
+ vocos.load_state_dict(state_dict)
107
+ vocos.eval()
108
+ else:
109
+ print("Donwload Vocos from huggingface charactr/vocos-mel-24khz")
110
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
111
 
112
  print(f"Using {device} device")
113
 
 
126
  fix_duration = None
127
 
128
  def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
129
+ ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
130
+ if not Path(ckpt_path).exists():
131
+ ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
132
  vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
133
  model = CFM(
134
  transformer=model_cls(
 
156
  )
157
  E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
158
 
159
+
160
+ def chunk_text(text, max_chars=135):
161
+ """
162
+ Splits the input text into chunks, each with a maximum number of characters.
163
+ Args:
164
+ text (str): The text to be split.
165
+ max_chars (int): The maximum number of characters per chunk.
166
+ Returns:
167
+ List[str]: A list of text chunks.
168
+ """
169
+ chunks = []
170
+ current_chunk = ""
171
+ # Split the text into sentences based on punctuation followed by whitespace
172
+ sentences = re.split(r'(?<=[;:,.!?])\s+|(?<=[;:,。!?])', text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
174
  for sentence in sentences:
175
+ if len(current_chunk.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
176
+ current_chunk += sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
177
  else:
178
+ if current_chunk:
179
+ chunks.append(current_chunk.strip())
180
+ current_chunk = sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
+ if current_chunk:
183
+ chunks.append(current_chunk.strip())
184
+
185
+ return chunks
186
+
187
+
188
+ def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence, cross_fade_duration=0.15):
189
  if model == "F5-TTS":
190
  ema_model = load_model(model, "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
191
  elif model == "E2-TTS":
 
243
  generated_waves.append(generated_wave)
244
  spectrograms.append(generated_mel_spec[0].cpu().numpy())
245
 
246
+ # Combine all generated waves with cross-fading
247
+ if cross_fade_duration <= 0:
248
+ # Simply concatenate
249
+ final_wave = np.concatenate(generated_waves)
250
+ else:
251
+ final_wave = generated_waves[0]
252
+ for i in range(1, len(generated_waves)):
253
+ prev_wave = final_wave
254
+ next_wave = generated_waves[i]
 
 
 
 
 
 
255
 
256
+ # Calculate cross-fade samples, ensuring it does not exceed wave lengths
257
+ cross_fade_samples = int(cross_fade_duration * target_sample_rate)
258
+ cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
 
259
 
260
+ if cross_fade_samples <= 0:
261
+ # No overlap possible, concatenate
262
+ final_wave = np.concatenate([prev_wave, next_wave])
263
+ continue
264
 
265
+ # Overlapping parts
266
+ prev_overlap = prev_wave[-cross_fade_samples:]
267
+ next_overlap = next_wave[:cross_fade_samples]
 
 
268
 
269
+ # Fade out and fade in
270
+ fade_out = np.linspace(1, 0, cross_fade_samples)
271
+ fade_in = np.linspace(0, 1, cross_fade_samples)
272
+
273
+ # Cross-faded overlap
274
+ cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
275
+
276
+ # Combine
277
+ new_wave = np.concatenate([
278
+ prev_wave[:-cross_fade_samples],
279
+ cross_faded_overlap,
280
+ next_wave[cross_fade_samples:]
281
+ ])
282
+
283
+ final_wave = new_wave
284
+
285
+ # Create a combined spectrogram
286
+ combined_spectrogram = np.concatenate(spectrograms, axis=1)
287
 
288
+ return final_wave, combined_spectrogram
289
+
290
+ def process_voice(ref_audio_orig, ref_text):
291
  print("Converting audio...")
292
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
293
  aseg = AudioSegment.from_file(ref_audio_orig)
294
 
295
+ non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000)
296
  non_silent_wave = AudioSegment.silent(duration=0)
297
  for non_silent_seg in non_silent_segs:
298
  non_silent_wave += non_silent_seg
 
323
  print("Finished transcription")
324
  else:
325
  print("Using custom reference text...")
326
+ return ref_audio, ref_text
327
+
328
+ def infer(ref_audio, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15):
329
+ print(gen_text)
330
+ # Ensure ref_text ends with ". " so there is a clear sentence boundary before generation
331
+ if not ref_text.endswith(". ") and not ref_text.endswith("。"):
332
+ if ref_text.endswith("."):
333
+ ref_text += " "
334
+ else:
335
+ ref_text += ". "
336
 
337
  # Split the input text into batches
338
  audio, sr = torchaudio.load(ref_audio)
339
+ max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
340
+ gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
341
  print('ref_text', ref_text)
342
  for i, gen_text in enumerate(gen_text_batches):
343
  print(f'gen_text {i}', gen_text)
344
 
345
  print(f"Generating audio using {model} in {len(gen_text_batches)} batches, loading models...")
346
+ return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence, cross_fade_duration)
347
 
348
 
349
+ def process(ref_audio, ref_text, text_gen, model, remove_silence):
350
+ main_voice = {"ref_audio":ref_audio, "ref_text":ref_text}
351
+ if "voices" not in config:
352
+ voices = {"main": main_voice}
353
+ else:
354
+ voices = config["voices"]
355
+ voices["main"] = main_voice
356
+ for voice in voices:
357
+ voices[voice]['ref_audio'], voices[voice]['ref_text'] = process_voice(voices[voice]['ref_audio'], voices[voice]['ref_text'])
358
+
359
+ generated_audio_segments = []
360
+ reg1 = r'(?=\[\w+\])'
361
+ chunks = re.split(reg1, text_gen)
362
+ reg2 = r'\[(\w+)\]'
363
+ for text in chunks:
364
+ match = re.match(reg2, text)
365
+ if not match or voice not in voices:
366
+ voice = "main"
367
+ else:
368
+ voice = match[1]
369
+ text = re.sub(reg2, "", text)
370
+ gen_text = text.strip()
371
+ ref_audio = voices[voice]['ref_audio']
372
+ ref_text = voices[voice]['ref_text']
373
+ print(f"Voice: {voice}")
374
+ audio, spectrogram = infer(ref_audio, ref_text, gen_text, model, remove_silence)
375
+ generated_audio_segments.append(audio)
376
+
377
+ if generated_audio_segments:
378
+ final_wave = np.concatenate(generated_audio_segments)
379
+ with open(wave_path, "wb") as f:
380
+ sf.write(f.name, final_wave, target_sample_rate)
381
+ # Remove silence
382
+ if remove_silence:
383
+ aseg = AudioSegment.from_file(f.name)
384
+ non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
385
+ non_silent_wave = AudioSegment.silent(duration=0)
386
+ for non_silent_seg in non_silent_segs:
387
+ non_silent_wave += non_silent_seg
388
+ aseg = non_silent_wave
389
+ aseg.export(f.name, format="wav")
390
+ print(f.name)
391
+
392
+ process(ref_audio, ref_text, gen_text, model, remove_silence)
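
The multi-voice routing in `process` relies on two small regexes; a minimal sketch of that logic on a made-up string (the tags correspond to keys under `[voices]` in the config):

```python
import re

text_gen = "Narration here. [town] A line for the town voice. [main] Back to the main voice."
for chunk in re.split(r"(?=\[\w+\])", text_gen):   # split before each [tag], keeping the tag
    match = re.match(r"\[(\w+)\]", chunk)
    voice = match[1] if match else "main"          # untagged text falls back to "main"
    print(voice, "->", re.sub(r"\[(\w+)\]", "", chunk).strip())
```

The CLI additionally falls back to "main" when a tag does not match any configured voice.
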
inference-cli.toml CHANGED
@@ -4,5 +4,7 @@ ref_audio = "tests/ref_audio/test_en_1_ref_short.wav"
4
  # If an empty "", transcribes the reference audio automatically.
5
  ref_text = "Some call me nature, others call me mother nature."
6
  gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
7
- remove_silence = true
 
 
8
  output_dir = "tests"
 
4
  # If an empty "", transcribes the reference audio automatically.
5
  ref_text = "Some call me nature, others call me mother nature."
6
  gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
7
+ # File with text to generate. Ignores the text above.
8
+ gen_file = ""
9
+ remove_silence = false
10
  output_dir = "tests"
model/dataset.py CHANGED
@@ -184,11 +184,15 @@ class DynamicBatchSampler(Sampler[list[int]]):
184
 
185
  def load_dataset(
186
  dataset_name: str,
187
- tokenizer: str,
188
  dataset_type: str = "CustomDataset",
189
  audio_type: str = "raw",
190
  mel_spec_kwargs: dict = dict()
191
- ) -> CustomDataset:
 
 
 
 
192
 
193
  print("Loading dataset ...")
194
 
@@ -206,7 +210,18 @@ def load_dataset(
206
  data_dict = json.load(f)
207
  durations = data_dict["duration"]
208
  train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)
209
-
 
 
 
 
 
 
 
 
 
 
 
210
  elif dataset_type == "HFDataset":
211
  print("Should manually modify the path of huggingface dataset to your need.\n" +
212
  "May also the corresponding script cuz different dataset may have different format.")
 
184
 
185
  def load_dataset(
186
  dataset_name: str,
187
+ tokenizer: str = "pinyin",
188
  dataset_type: str = "CustomDataset",
189
  audio_type: str = "raw",
190
  mel_spec_kwargs: dict = dict()
191
+ ) -> CustomDataset | HFDataset:
192
+ '''
193
+ dataset_type - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset
194
+ - "CustomDatasetPath" if you just want to pass the full path to a preprocessed dataset without relying on tokenizer
195
+ '''
196
 
197
  print("Loading dataset ...")
198
 
 
210
  data_dict = json.load(f)
211
  durations = data_dict["duration"]
212
  train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)
213
+
214
+ elif dataset_type == "CustomDatasetPath":
215
+ try:
216
+ train_dataset = load_from_disk(f"{dataset_name}/raw")
217
+ except:
218
+ train_dataset = Dataset_.from_file(f"{dataset_name}/raw.arrow")
219
+
220
+ with open(f"{dataset_name}/duration.json", 'r', encoding='utf-8') as f:
221
+ data_dict = json.load(f)
222
+ durations = data_dict["duration"]
223
+ train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)
224
+
225
  elif dataset_type == "HFDataset":
226
  print("Should manually modify the path of huggingface dataset to your need.\n" +
227
  "May also the corresponding script cuz different dataset may have different format.")
model/trainer.py CHANGED
@@ -140,7 +140,7 @@ class Trainer:
140
  else:
141
  latest_checkpoint = sorted([f for f in os.listdir(self.checkpoint_path) if f.endswith('.pt')], key=lambda x: int(''.join(filter(str.isdigit, x))))[-1]
142
  # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
143
- checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location="cpu")
144
 
145
  if self.is_main:
146
  self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
 
140
  else:
141
  latest_checkpoint = sorted([f for f in os.listdir(self.checkpoint_path) if f.endswith('.pt')], key=lambda x: int(''.join(filter(str.isdigit, x))))[-1]
142
  # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
143
+ checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
144
 
145
  if self.is_main:
146
  self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
model/utils.py CHANGED
@@ -22,12 +22,6 @@ from einops import rearrange, reduce
22
 
23
  import jieba
24
  from pypinyin import lazy_pinyin, Style
25
- import zhconv
26
- from zhon.hanzi import punctuation
27
- from jiwer import compute_measures
28
-
29
- from funasr import AutoModel
30
- from faster_whisper import WhisperModel
31
 
32
  from model.ecapa_tdnn import ECAPA_TDNN_SMALL
33
  from model.modules import MelSpec
@@ -129,6 +123,7 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
129
  tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
130
  - "char" for char-wise tokenizer, need .txt vocab_file
131
  - "byte" for utf-8 tokenizer
 
132
  vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
133
  - if use "char", derived from unfiltered character & symbol counts of custom dataset
134
  - if use "byte", set to 256 (unicode byte range)
@@ -144,6 +139,12 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
144
  elif tokenizer == "byte":
145
  vocab_char_map = None
146
  vocab_size = 256
 
 
 
 
 
 
147
 
148
  return vocab_char_map, vocab_size
149
 
@@ -425,6 +426,7 @@ def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path
425
 
426
  def load_asr_model(lang, ckpt_dir = ""):
427
  if lang == "zh":
 
428
  model = AutoModel(
429
  model = os.path.join(ckpt_dir, "paraformer-zh"),
430
  # vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
@@ -433,6 +435,7 @@ def load_asr_model(lang, ckpt_dir = ""):
433
  disable_update=True,
434
  ) # following seed-tts setting
435
  elif lang == "en":
 
436
  model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
437
  model = WhisperModel(model_size, device="cuda", compute_type="float16")
438
  return model
@@ -444,6 +447,7 @@ def run_asr_wer(args):
444
  rank, lang, test_set, ckpt_dir = args
445
 
446
  if lang == "zh":
 
447
  torch.cuda.set_device(rank)
448
  elif lang == "en":
449
  os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
@@ -451,10 +455,12 @@ def run_asr_wer(args):
451
  raise NotImplementedError("lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now.")
452
 
453
  asr_model = load_asr_model(lang, ckpt_dir = ckpt_dir)
454
-
 
455
  punctuation_all = punctuation + string.punctuation
456
  wers = []
457
 
 
458
  for gen_wav, prompt_wav, truth in tqdm(test_set):
459
  if lang == "zh":
460
  res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
@@ -503,7 +509,7 @@ def run_sim(args):
503
  device = f"cuda:{rank}"
504
 
505
  model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large', config_path=None)
506
- state_dict = torch.load(ckpt_dir, map_location=lambda storage, loc: storage)
507
  model.load_state_dict(state_dict['model'], strict=False)
508
 
509
  use_gpu=True if torch.cuda.is_available() else False
@@ -559,7 +565,7 @@ def load_checkpoint(model, ckpt_path, device, use_ema = True):
559
  from safetensors.torch import load_file
560
  checkpoint = load_file(ckpt_path, device=device)
561
  else:
562
- checkpoint = torch.load(ckpt_path, map_location=device)
563
 
564
  if use_ema == True:
565
  ema_model = EMA(model, include_online_model = False).to(device)
 
22
 
23
  import jieba
24
  from pypinyin import lazy_pinyin, Style
 
 
 
 
 
 
25
 
26
  from model.ecapa_tdnn import ECAPA_TDNN_SMALL
27
  from model.modules import MelSpec
 
123
  tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
124
  - "char" for char-wise tokenizer, need .txt vocab_file
125
  - "byte" for utf-8 tokenizer
126
+ - "custom" if you're directly passing in a path to the vocab.txt you want to use
127
  vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
128
  - if use "char", derived from unfiltered character & symbol counts of custom dataset
129
  - if use "byte", set to 256 (unicode byte range)
 
139
  elif tokenizer == "byte":
140
  vocab_char_map = None
141
  vocab_size = 256
142
+ elif tokenizer == "custom":
143
+ with open(dataset_name, "r", encoding="utf-8") as f:
144
+ vocab_char_map = {}
145
+ for i, char in enumerate(f):
146
+ vocab_char_map[char[:-1]] = i
147
+ vocab_size = len(vocab_char_map)
148
 
149
  return vocab_char_map, vocab_size
150
 
 
426
 
427
  def load_asr_model(lang, ckpt_dir = ""):
428
  if lang == "zh":
429
+ from funasr import AutoModel
430
  model = AutoModel(
431
  model = os.path.join(ckpt_dir, "paraformer-zh"),
432
  # vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
 
435
  disable_update=True,
436
  ) # following seed-tts setting
437
  elif lang == "en":
438
+ from faster_whisper import WhisperModel
439
  model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
440
  model = WhisperModel(model_size, device="cuda", compute_type="float16")
441
  return model
 
447
  rank, lang, test_set, ckpt_dir = args
448
 
449
  if lang == "zh":
450
+ import zhconv
451
  torch.cuda.set_device(rank)
452
  elif lang == "en":
453
  os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
 
455
  raise NotImplementedError("lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now.")
456
 
457
  asr_model = load_asr_model(lang, ckpt_dir = ckpt_dir)
458
+
459
+ from zhon.hanzi import punctuation
460
  punctuation_all = punctuation + string.punctuation
461
  wers = []
462
 
463
+ from jiwer import compute_measures
464
  for gen_wav, prompt_wav, truth in tqdm(test_set):
465
  if lang == "zh":
466
  res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
 
509
  device = f"cuda:{rank}"
510
 
511
  model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large', config_path=None)
512
+ state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage)
513
  model.load_state_dict(state_dict['model'], strict=False)
514
 
515
  use_gpu=True if torch.cuda.is_available() else False
 
565
  from safetensors.torch import load_file
566
  checkpoint = load_file(ckpt_path, device=device)
567
  else:
568
+ checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
569
 
570
  if use_ema == True:
571
  ema_model = EMA(model, include_online_model = False).to(device)
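
The new `custom` tokenizer option takes a path to a `vocab.txt` (one token per line) in place of a dataset name; a minimal usage sketch with an example path:

```python
from model.utils import get_tokenizer

vocab_char_map, vocab_size = get_tokenizer("data/my_speak_pinyin/vocab.txt", "custom")
print(vocab_size)
```
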
requirements.txt CHANGED
@@ -5,25 +5,19 @@ datasets
5
  einops>=0.8.0
6
  einx>=0.3.0
7
  ema_pytorch>=0.5.2
8
- faster_whisper
9
- funasr
10
  gradio
11
  jieba
12
- jiwer
13
  librosa
14
  matplotlib
15
- numpy==1.23.5
16
  pydub
17
  pypinyin
18
  safetensors
19
  soundfile
20
- # torch>=2.0
21
- # torchaudio>=2.3.0
22
  torchdiffeq
23
  tqdm>=4.65.0
24
  transformers
25
  vocos
26
  wandb
27
  x_transformers>=1.31.14
28
- zhconv
29
- zhon
 
5
  einops>=0.8.0
6
  einx>=0.3.0
7
  ema_pytorch>=0.5.2
 
 
8
  gradio
9
  jieba
 
10
  librosa
11
  matplotlib
12
+ numpy<=1.26.4
13
  pydub
14
  pypinyin
15
  safetensors
16
  soundfile
17
+ tomli
 
18
  torchdiffeq
19
  tqdm>=4.65.0
20
  transformers
21
  vocos
22
  wandb
23
  x_transformers>=1.31.14
 
 
requirements_eval.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ faster_whisper
2
+ funasr
3
+ jiwer
4
+ zhconv
5
+ zhon
samples/country.flac ADDED
Binary file (180 kB). View file
 
samples/main.flac ADDED
Binary file (279 kB). View file
 
samples/story.toml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # F5-TTS | E2-TTS
2
+ model = "F5-TTS"
3
+ ref_audio = "samples/main.flac"
4
+ # If an empty "", transcribes the reference audio automatically.
5
+ ref_text = ""
6
+ gen_text = ""
7
+ # File with text to generate. Ignores the text above.
8
+ gen_file = "samples/story.txt"
9
+ remove_silence = true
10
+ output_dir = "samples"
11
+
12
+ [voices.town]
13
+ ref_audio = "samples/town.flac"
14
+ ref_text = ""
15
+
16
+ [voices.country]
17
+ ref_audio = "samples/country.flac"
18
+ ref_text = ""
19
+
samples/story.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ A Town Mouse and a Country Mouse were acquaintances, and the Country Mouse one day invited his friend to come and see him at his home in the fields. The Town Mouse came, and they sat down to a dinner of barleycorns and roots, the latter of which had a distinctly earthy flavour. The fare was not much to the taste of the guest, and presently he broke out with [town] “My poor dear friend, you live here no better than the ants. Now, you should just see how I fare! My larder is a regular horn of plenty. You must come and stay with me, and I promise you you shall live on the fat of the land.” [main] So when he returned to town he took the Country Mouse with him, and showed him into a larder containing flour and oatmeal and figs and honey and dates. The Country Mouse had never seen anything like it, and sat down to enjoy the luxuries his friend provided: but before they had well begun, the door of the larder opened and someone came in. The two Mice scampered off and hid themselves in a narrow and exceedingly uncomfortable hole. Presently, when all was quiet, they ventured out again; but someone else came in, and off they scuttled again. This was too much for the visitor. [country] “Goodbye,” [main] said he, [country] “I’m off. You live in the lap of luxury, I can see, but you are surrounded by dangers; whereas at home I can enjoy my simple dinner of roots and corn in peace.”
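
The `[town]`, `[country]` and `[main]` tags in this sample map to the voices defined in `samples/story.toml`; a small sketch that cross-checks the two files:

```python
import re
import tomli

with open("samples/story.toml", "rb") as f:
    cfg = tomli.load(f)
story = open(cfg["gen_file"], encoding="utf-8").read()

used = set(re.findall(r"\[(\w+)\]", story))
known = {"main", *cfg.get("voices", {})}
print(used, used <= known)  # every tag used should be "main" or a [voices.*] entry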
samples/town.flac ADDED
Binary file (229 kB). View file
 
scripts/eval_infer_batch.py CHANGED
@@ -127,7 +127,7 @@ local = False
127
  if local:
128
  vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
129
  vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
130
- state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
131
  vocos.load_state_dict(state_dict)
132
  vocos.eval()
133
  else:
 
127
  if local:
128
  vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
129
  vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
130
+ state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
131
  vocos.load_state_dict(state_dict)
132
  vocos.eval()
133
  else:
scripts/prepare_csv_wavs.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys, os
2
+ sys.path.append(os.getcwd())
3
+
4
+ from pathlib import Path
5
+ import json
6
+ import shutil
7
+ import argparse
8
+
9
+ import csv
10
+ import torchaudio
11
+ from tqdm import tqdm
12
+ from datasets.arrow_writer import ArrowWriter
13
+
14
+ from model.utils import (
15
+ convert_char_to_pinyin,
16
+ )
17
+
18
+ PRETRAINED_VOCAB_PATH = Path(__file__).parent.parent / "data/Emilia_ZH_EN_pinyin/vocab.txt"
19
+
20
+ def is_csv_wavs_format(input_dataset_dir):
21
+ fpath = Path(input_dataset_dir)
22
+ metadata = fpath / "metadata.csv"
23
+ wavs = fpath / 'wavs'
24
+ return metadata.exists() and metadata.is_file() and wavs.exists() and wavs.is_dir()
25
+
26
+
27
+ def prepare_csv_wavs_dir(input_dir):
28
+ assert is_csv_wavs_format(input_dir), f"not csv_wavs format: {input_dir}"
29
+ input_dir = Path(input_dir)
30
+ metadata_path = input_dir / "metadata.csv"
31
+ audio_path_text_pairs = read_audio_text_pairs(metadata_path.as_posix())
32
+
33
+ sub_result, durations = [], []
34
+ vocab_set = set()
35
+ polyphone = True
36
+ for audio_path, text in audio_path_text_pairs:
37
+ if not Path(audio_path).exists():
38
+ print(f"audio {audio_path} not found, skipping")
39
+ continue
40
+ audio_duration = get_audio_duration(audio_path)
41
+ # assume tokenizer = "pinyin" ("pinyin" | "char")
42
+ text = convert_char_to_pinyin([text], polyphone=polyphone)[0]
43
+ sub_result.append({"audio_path": audio_path, "text": text, "duration": audio_duration})
44
+ durations.append(audio_duration)
45
+ vocab_set.update(list(text))
46
+
47
+ return sub_result, durations, vocab_set
48
+
49
+ def get_audio_duration(audio_path):
50
+ audio, sample_rate = torchaudio.load(audio_path)
51
+ num_channels = audio.shape[0]
52
+ return audio.shape[1] / (sample_rate * num_channels)
53
+
54
+ def read_audio_text_pairs(csv_file_path):
55
+ audio_text_pairs = []
56
+
57
+ parent = Path(csv_file_path).parent
58
+ with open(csv_file_path, mode='r', newline='', encoding='utf-8') as csvfile:
59
+ reader = csv.reader(csvfile, delimiter='|')
60
+ next(reader) # Skip the header row
61
+ for row in reader:
62
+ if len(row) >= 2:
63
+ audio_file = row[0].strip() # First column: audio file path
64
+ text = row[1].strip() # Second column: text
65
+ audio_file_path = parent / audio_file
66
+ audio_text_pairs.append((audio_file_path.as_posix(), text))
67
+
68
+ return audio_text_pairs
69
+
70
+
71
+ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_finetune):
72
+ out_dir = Path(out_dir)
73
+ # save preprocessed dataset to disk
74
+ out_dir.mkdir(exist_ok=True, parents=True)
75
+ print(f"\nSaving to {out_dir} ...")
76
+
77
+ # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom
78
+ # dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
79
+ raw_arrow_path = out_dir / "raw.arrow"
80
+ with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=1) as writer:
81
+ for line in tqdm(result, desc=f"Writing to raw.arrow ..."):
82
+ writer.write(line)
83
+
84
+ # dup a json separately saving duration in case for DynamicBatchSampler ease
85
+ dur_json_path = out_dir / "duration.json"
86
+ with open(dur_json_path.as_posix(), 'w', encoding='utf-8') as f:
87
+ json.dump({"duration": duration_list}, f, ensure_ascii=False)
88
+
89
+ # vocab map, i.e. tokenizer
90
+ # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
91
+ # if tokenizer == "pinyin":
92
+ # text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
93
+ voca_out_path = out_dir / "vocab.txt"
94
+ with open(voca_out_path.as_posix(), "w") as f:
95
+ for vocab in sorted(text_vocab_set):
96
+ f.write(vocab + "\n")
97
+
98
+ if is_finetune:
99
+ file_vocab_finetune = PRETRAINED_VOCAB_PATH.as_posix()
100
+ shutil.copy2(file_vocab_finetune, voca_out_path)
101
+ else:
102
+ with open(voca_out_path, "w") as f:
103
+ for vocab in sorted(text_vocab_set):
104
+ f.write(vocab + "\n")
105
+
106
+ dataset_name = out_dir.stem
107
+ print(f"\nFor {dataset_name}, sample count: {len(result)}")
108
+ print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
109
+ print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
110
+
111
+
112
+ def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True):
113
+ if is_finetune:
114
+ assert PRETRAINED_VOCAB_PATH.exists(), f"pretrained vocab.txt not found: {PRETRAINED_VOCAB_PATH}"
115
+ sub_result, durations, vocab_set = prepare_csv_wavs_dir(inp_dir)
116
+ save_prepped_dataset(out_dir, sub_result, durations, vocab_set, is_finetune)
117
+
118
+
119
+ def cli():
120
+ # finetune: python scripts/prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin
121
+ # pretrain: python scripts/prepare_csv_wavs.py /path/to/output_dir_pinyin --pretrain
122
+ parser = argparse.ArgumentParser(description="Prepare and save dataset.")
123
+ parser.add_argument('inp_dir', type=str, help="Input directory containing the data.")
124
+ parser.add_argument('out_dir', type=str, help="Output directory to save the prepared data.")
125
+ parser.add_argument('--pretrain', action='store_true', help="Enable for new pretrain, otherwise is a fine-tune")
126
+
127
+ args = parser.parse_args()
128
+
129
+ prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain)
130
+
131
+ if __name__ == "__main__":
132
+ cli()
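
The script expects a pipe-delimited `metadata.csv` whose first row is a header and whose audio paths are relative to the CSV; a minimal sketch that writes such a file (directory and transcripts are examples):

```python
from pathlib import Path

dataset_dir = Path("data/my_speak")                    # example input directory
(dataset_dir / "wavs").mkdir(parents=True, exist_ok=True)
rows = [
    "audio_file|text",                                 # header row, skipped by read_audio_text_pairs()
    "wavs/audio1.wav|Transcription of the first clip.",
    "wavs/audio2.wav|Transcription of the second clip.",
]
(dataset_dir / "metadata.csv").write_text("\n".join(rows) + "\n", encoding="utf-8")
# then: python scripts/prepare_csv_wavs.py data/my_speak data/my_speak_pinyin
```
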
speech_edit.py CHANGED
@@ -85,8 +85,9 @@ local = False
85
  if local:
86
  vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
87
  vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
88
- state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
89
  vocos.load_state_dict(state_dict)
 
90
  vocos.eval()
91
  else:
92
  vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
 
85
  if local:
86
  vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
87
  vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
88
+ state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
89
  vocos.load_state_dict(state_dict)
90
+
91
  vocos.eval()
92
  else:
93
  vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
train.py CHANGED
@@ -9,10 +9,10 @@ target_sample_rate = 24000
9
  n_mel_channels = 100
10
  hop_length = 256
11
 
12
- tokenizer = "pinyin"
 
13
  dataset_name = "Emilia_ZH_EN"
14
 
15
-
16
  # -------------------------- Training Settings -------------------------- #
17
 
18
  exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
@@ -44,8 +44,11 @@ elif exp_name == "E2TTS_Base":
44
  # ----------------------------------------------------------------------- #
45
 
46
  def main():
47
-
48
- vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
 
 
 
49
 
50
  mel_spec_kwargs = dict(
51
  target_sample_rate = target_sample_rate,
 
9
  n_mel_channels = 100
10
  hop_length = 256
11
 
12
+ tokenizer = "pinyin" # 'pinyin', 'char', or 'custom'
13
+ tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
14
  dataset_name = "Emilia_ZH_EN"
15
 
 
16
  # -------------------------- Training Settings -------------------------- #
17
 
18
  exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
 
44
  # ----------------------------------------------------------------------- #
45
 
46
  def main():
47
+ if tokenizer == "custom":
48
+ tokenizer_select = tokenizer_path
49
+ else:
50
+ tokenizer_select = dataset_name
51
+ vocab_char_map, vocab_size = get_tokenizer(tokenizer_select, tokenizer)
52
 
53
  mel_spec_kwargs = dict(
54
  target_sample_rate = target_sample_rate,