Hemant0000 commited on
Commit
8e3c3c1
Β·
verified Β·
1 Parent(s): 8c9588c

Create finetune_gradio.py

Browse files
Files changed (1) hide show
  1. finetune_gradio.py +944 -0
finetune_gradio.py ADDED
@@ -0,0 +1,944 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ import tempfile
5
+ import random
6
+ from transformers import pipeline
7
+ import gradio as gr
8
+ import torch
9
+ import gc
10
+ import click
11
+ import torchaudio
12
+ from glob import glob
13
+ import librosa
14
+ import numpy as np
15
+ from scipy.io import wavfile
16
+ import shutil
17
+ import time
18
+
19
+ import json
20
+ from model.utils import convert_char_to_pinyin
21
+ import signal
22
+ import psutil
23
+ import platform
24
+ import subprocess
25
+ from datasets.arrow_writer import ArrowWriter
26
+ from datasets import Dataset as Dataset_
27
+ from api import F5TTS
28
+
29
+
30
+ training_process = None
31
+ system = platform.system()
32
+ python_executable = sys.executable or "python"
33
+ tts_api = None
34
+ last_checkpoint = ""
35
+ last_device = ""
36
+
37
+ path_data = "data"
38
+
39
+ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
40
+
41
+ pipe = None
42
+
43
+
44
+ # Load metadata
45
+ def get_audio_duration(audio_path):
46
+ """Calculate the duration of an audio file."""
47
+ audio, sample_rate = torchaudio.load(audio_path)
48
+ num_channels = audio.shape[0]
49
+ return audio.shape[1] / (sample_rate * num_channels)
50
+
51
+
52
+ def clear_text(text):
53
+ """Clean and prepare text by lowering the case and stripping whitespace."""
54
+ return text.lower().strip()
55
+
56
+
57
+ def get_rms(
58
+ y,
59
+ frame_length=2048,
60
+ hop_length=512,
61
+ pad_mode="constant",
62
+ ): # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
63
+ padding = (int(frame_length // 2), int(frame_length // 2))
64
+ y = np.pad(y, padding, mode=pad_mode)
65
+
66
+ axis = -1
67
+ # put our new within-frame axis at the end for now
68
+ out_strides = y.strides + tuple([y.strides[axis]])
69
+ # Reduce the shape on the framing axis
70
+ x_shape_trimmed = list(y.shape)
71
+ x_shape_trimmed[axis] -= frame_length - 1
72
+ out_shape = tuple(x_shape_trimmed) + tuple([frame_length])
73
+ xw = np.lib.stride_tricks.as_strided(y, shape=out_shape, strides=out_strides)
74
+ if axis < 0:
75
+ target_axis = axis - 1
76
+ else:
77
+ target_axis = axis + 1
78
+ xw = np.moveaxis(xw, -1, target_axis)
79
+ # Downsample along the target axis
80
+ slices = [slice(None)] * xw.ndim
81
+ slices[axis] = slice(0, None, hop_length)
82
+ x = xw[tuple(slices)]
83
+
84
+ # Calculate power
85
+ power = np.mean(np.abs(x) ** 2, axis=-2, keepdims=True)
86
+
87
+ return np.sqrt(power)
88
+
89
+
90
+ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
91
+ def __init__(
92
+ self,
93
+ sr: int,
94
+ threshold: float = -40.0,
95
+ min_length: int = 2000,
96
+ min_interval: int = 300,
97
+ hop_size: int = 20,
98
+ max_sil_kept: int = 2000,
99
+ ):
100
+ if not min_length >= min_interval >= hop_size:
101
+ raise ValueError("The following condition must be satisfied: min_length >= min_interval >= hop_size")
102
+ if not max_sil_kept >= hop_size:
103
+ raise ValueError("The following condition must be satisfied: max_sil_kept >= hop_size")
104
+ min_interval = sr * min_interval / 1000
105
+ self.threshold = 10 ** (threshold / 20.0)
106
+ self.hop_size = round(sr * hop_size / 1000)
107
+ self.win_size = min(round(min_interval), 4 * self.hop_size)
108
+ self.min_length = round(sr * min_length / 1000 / self.hop_size)
109
+ self.min_interval = round(min_interval / self.hop_size)
110
+ self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)
111
+
112
+ def _apply_slice(self, waveform, begin, end):
113
+ if len(waveform.shape) > 1:
114
+ return waveform[:, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size)]
115
+ else:
116
+ return waveform[begin * self.hop_size : min(waveform.shape[0], end * self.hop_size)]
117
+
118
+ # @timeit
119
+ def slice(self, waveform):
120
+ if len(waveform.shape) > 1:
121
+ samples = waveform.mean(axis=0)
122
+ else:
123
+ samples = waveform
124
+ if samples.shape[0] <= self.min_length:
125
+ return [waveform]
126
+ rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
127
+ sil_tags = []
128
+ silence_start = None
129
+ clip_start = 0
130
+ for i, rms in enumerate(rms_list):
131
+ # Keep looping while frame is silent.
132
+ if rms < self.threshold:
133
+ # Record start of silent frames.
134
+ if silence_start is None:
135
+ silence_start = i
136
+ continue
137
+ # Keep looping while frame is not silent and silence start has not been recorded.
138
+ if silence_start is None:
139
+ continue
140
+ # Clear recorded silence start if interval is not enough or clip is too short
141
+ is_leading_silence = silence_start == 0 and i > self.max_sil_kept
142
+ need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length
143
+ if not is_leading_silence and not need_slice_middle:
144
+ silence_start = None
145
+ continue
146
+ # Need slicing. Record the range of silent frames to be removed.
147
+ if i - silence_start <= self.max_sil_kept:
148
+ pos = rms_list[silence_start : i + 1].argmin() + silence_start
149
+ if silence_start == 0:
150
+ sil_tags.append((0, pos))
151
+ else:
152
+ sil_tags.append((pos, pos))
153
+ clip_start = pos
154
+ elif i - silence_start <= self.max_sil_kept * 2:
155
+ pos = rms_list[i - self.max_sil_kept : silence_start + self.max_sil_kept + 1].argmin()
156
+ pos += i - self.max_sil_kept
157
+ pos_l = rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start
158
+ pos_r = rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept
159
+ if silence_start == 0:
160
+ sil_tags.append((0, pos_r))
161
+ clip_start = pos_r
162
+ else:
163
+ sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
164
+ clip_start = max(pos_r, pos)
165
+ else:
166
+ pos_l = rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start
167
+ pos_r = rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept
168
+ if silence_start == 0:
169
+ sil_tags.append((0, pos_r))
170
+ else:
171
+ sil_tags.append((pos_l, pos_r))
172
+ clip_start = pos_r
173
+ silence_start = None
174
+ # Deal with trailing silence.
175
+ total_frames = rms_list.shape[0]
176
+ if silence_start is not None and total_frames - silence_start >= self.min_interval:
177
+ silence_end = min(total_frames, silence_start + self.max_sil_kept)
178
+ pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
179
+ sil_tags.append((pos, total_frames + 1))
180
+ # Apply and return slices.
181
+ ####ιŸ³ι’‘+衷始既间+η»ˆζ­’ζ—Άι—΄
182
+ if len(sil_tags) == 0:
183
+ return [[waveform, 0, int(total_frames * self.hop_size)]]
184
+ else:
185
+ chunks = []
186
+ if sil_tags[0][0] > 0:
187
+ chunks.append([self._apply_slice(waveform, 0, sil_tags[0][0]), 0, int(sil_tags[0][0] * self.hop_size)])
188
+ for i in range(len(sil_tags) - 1):
189
+ chunks.append(
190
+ [
191
+ self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]),
192
+ int(sil_tags[i][1] * self.hop_size),
193
+ int(sil_tags[i + 1][0] * self.hop_size),
194
+ ]
195
+ )
196
+ if sil_tags[-1][1] < total_frames:
197
+ chunks.append(
198
+ [
199
+ self._apply_slice(waveform, sil_tags[-1][1], total_frames),
200
+ int(sil_tags[-1][1] * self.hop_size),
201
+ int(total_frames * self.hop_size),
202
+ ]
203
+ )
204
+ return chunks
205
+
206
+
207
+ # terminal
208
+ def terminate_process_tree(pid, including_parent=True):
209
+ try:
210
+ parent = psutil.Process(pid)
211
+ except psutil.NoSuchProcess:
212
+ # Process already terminated
213
+ return
214
+
215
+ children = parent.children(recursive=True)
216
+ for child in children:
217
+ try:
218
+ os.kill(child.pid, signal.SIGTERM) # or signal.SIGKILL
219
+ except OSError:
220
+ pass
221
+ if including_parent:
222
+ try:
223
+ os.kill(parent.pid, signal.SIGTERM) # or signal.SIGKILL
224
+ except OSError:
225
+ pass
226
+
227
+
228
+ def terminate_process(pid):
229
+ if system == "Windows":
230
+ cmd = f"taskkill /t /f /pid {pid}"
231
+ os.system(cmd)
232
+ else:
233
+ terminate_process_tree(pid)
234
+
235
+
236
+ def start_training(
237
+ dataset_name="",
238
+ exp_name="F5TTS_Base",
239
+ learning_rate=1e-4,
240
+ batch_size_per_gpu=400,
241
+ batch_size_type="frame",
242
+ max_samples=64,
243
+ grad_accumulation_steps=1,
244
+ max_grad_norm=1.0,
245
+ epochs=11,
246
+ num_warmup_updates=200,
247
+ save_per_updates=400,
248
+ last_per_steps=800,
249
+ finetune=True,
250
+ ):
251
+ global training_process, tts_api
252
+
253
+ if tts_api is not None:
254
+ del tts_api
255
+ gc.collect()
256
+ torch.cuda.empty_cache()
257
+ tts_api = None
258
+
259
+ path_project = os.path.join(path_data, dataset_name + "_pinyin")
260
+
261
+ if not os.path.isdir(path_project):
262
+ yield (
263
+ f"There is not project with name {dataset_name}",
264
+ gr.update(interactive=True),
265
+ gr.update(interactive=False),
266
+ )
267
+ return
268
+
269
+ file_raw = os.path.join(path_project, "raw.arrow")
270
+ if not os.path.isfile(file_raw):
271
+ yield f"There is no file {file_raw}", gr.update(interactive=True), gr.update(interactive=False)
272
+ return
273
+
274
+ # Check if a training process is already running
275
+ if training_process is not None:
276
+ return "Train run already!", gr.update(interactive=False), gr.update(interactive=True)
277
+
278
+ yield "start train", gr.update(interactive=False), gr.update(interactive=False)
279
+
280
+ # Command to run the training script with the specified arguments
281
+ cmd = (
282
+ f"accelerate launch finetune-cli.py --exp_name {exp_name} "
283
+ f"--learning_rate {learning_rate} "
284
+ f"--batch_size_per_gpu {batch_size_per_gpu} "
285
+ f"--batch_size_type {batch_size_type} "
286
+ f"--max_samples {max_samples} "
287
+ f"--grad_accumulation_steps {grad_accumulation_steps} "
288
+ f"--max_grad_norm {max_grad_norm} "
289
+ f"--epochs {epochs} "
290
+ f"--num_warmup_updates {num_warmup_updates} "
291
+ f"--save_per_updates {save_per_updates} "
292
+ f"--last_per_steps {last_per_steps} "
293
+ f"--dataset_name {dataset_name}"
294
+ )
295
+ if finetune:
296
+ cmd += f" --finetune {finetune}"
297
+
298
+ print(cmd)
299
+
300
+ try:
301
+ # Start the training process
302
+ training_process = subprocess.Popen(cmd, shell=True)
303
+
304
+ time.sleep(5)
305
+ yield "train start", gr.update(interactive=False), gr.update(interactive=True)
306
+
307
+ # Wait for the training process to finish
308
+ training_process.wait()
309
+ time.sleep(1)
310
+
311
+ if training_process is None:
312
+ text_info = "train stop"
313
+ else:
314
+ text_info = "train complete !"
315
+
316
+ except Exception as e: # Catch all exceptions
317
+ # Ensure that we reset the training process variable in case of an error
318
+ text_info = f"An error occurred: {str(e)}"
319
+
320
+ training_process = None
321
+
322
+ yield text_info, gr.update(interactive=True), gr.update(interactive=False)
323
+
324
+
325
+ def stop_training():
326
+ global training_process
327
+ if training_process is None:
328
+ return "Train not run !", gr.update(interactive=True), gr.update(interactive=False)
329
+ terminate_process_tree(training_process.pid)
330
+ training_process = None
331
+ return "train stop", gr.update(interactive=True), gr.update(interactive=False)
332
+
333
+
334
+ def create_data_project(name):
335
+ name += "_pinyin"
336
+ os.makedirs(os.path.join(path_data, name), exist_ok=True)
337
+ os.makedirs(os.path.join(path_data, name, "dataset"), exist_ok=True)
338
+
339
+
340
+ def transcribe(file_audio, language="english"):
341
+ global pipe
342
+
343
+ if pipe is None:
344
+ pipe = pipeline(
345
+ "automatic-speech-recognition",
346
+ model="openai/whisper-large-v3-turbo",
347
+ torch_dtype=torch.float16,
348
+ device=device,
349
+ )
350
+
351
+ text_transcribe = pipe(
352
+ file_audio,
353
+ chunk_length_s=30,
354
+ batch_size=128,
355
+ generate_kwargs={"task": "transcribe", "language": language},
356
+ return_timestamps=False,
357
+ )["text"].strip()
358
+ return text_transcribe
359
+
360
+
361
+ def transcribe_all(name_project, audio_files, language, user=False, progress=gr.Progress()):
362
+ name_project += "_pinyin"
363
+ path_project = os.path.join(path_data, name_project)
364
+ path_dataset = os.path.join(path_project, "dataset")
365
+ path_project_wavs = os.path.join(path_project, "wavs")
366
+ file_metadata = os.path.join(path_project, "metadata.csv")
367
+
368
+ if audio_files is None:
369
+ return "You need to load an audio file."
370
+
371
+ if os.path.isdir(path_project_wavs):
372
+ shutil.rmtree(path_project_wavs)
373
+
374
+ if os.path.isfile(file_metadata):
375
+ os.remove(file_metadata)
376
+
377
+ os.makedirs(path_project_wavs, exist_ok=True)
378
+
379
+ if user:
380
+ file_audios = [
381
+ file
382
+ for format in ("*.wav", "*.ogg", "*.opus", "*.mp3", "*.flac")
383
+ for file in glob(os.path.join(path_dataset, format))
384
+ ]
385
+ if file_audios == []:
386
+ return "No audio file was found in the dataset."
387
+ else:
388
+ file_audios = audio_files
389
+
390
+ alpha = 0.5
391
+ _max = 1.0
392
+ slicer = Slicer(24000)
393
+
394
+ num = 0
395
+ error_num = 0
396
+ data = ""
397
+ for file_audio in progress.tqdm(file_audios, desc="transcribe files", total=len((file_audios))):
398
+ audio, _ = librosa.load(file_audio, sr=24000, mono=True)
399
+
400
+ list_slicer = slicer.slice(audio)
401
+ for chunk, start, end in progress.tqdm(list_slicer, total=len(list_slicer), desc="slicer files"):
402
+ name_segment = os.path.join(f"segment_{num}")
403
+ file_segment = os.path.join(path_project_wavs, f"{name_segment}.wav")
404
+
405
+ tmp_max = np.abs(chunk).max()
406
+ if tmp_max > 1:
407
+ chunk /= tmp_max
408
+ chunk = (chunk / tmp_max * (_max * alpha)) + (1 - alpha) * chunk
409
+ wavfile.write(file_segment, 24000, (chunk * 32767).astype(np.int16))
410
+
411
+ try:
412
+ text = transcribe(file_segment, language)
413
+ text = text.lower().strip().replace('"', "")
414
+
415
+ data += f"{name_segment}|{text}\n"
416
+
417
+ num += 1
418
+ except: # noqa: E722
419
+ error_num += 1
420
+
421
+ with open(file_metadata, "w", encoding="utf-8") as f:
422
+ f.write(data)
423
+
424
+ if error_num != []:
425
+ error_text = f"\nerror files : {error_num}"
426
+ else:
427
+ error_text = ""
428
+
429
+ return f"transcribe complete samples : {num}\npath : {path_project_wavs}{error_text}"
430
+
431
+
432
+ def format_seconds_to_hms(seconds):
433
+ hours = int(seconds / 3600)
434
+ minutes = int((seconds % 3600) / 60)
435
+ seconds = seconds % 60
436
+ return "{:02d}:{:02d}:{:02d}".format(hours, minutes, int(seconds))
437
+
438
+
439
+ def create_metadata(name_project, progress=gr.Progress()):
440
+ name_project += "_pinyin"
441
+ path_project = os.path.join(path_data, name_project)
442
+ path_project_wavs = os.path.join(path_project, "wavs")
443
+ file_metadata = os.path.join(path_project, "metadata.csv")
444
+ file_raw = os.path.join(path_project, "raw.arrow")
445
+ file_duration = os.path.join(path_project, "duration.json")
446
+ file_vocab = os.path.join(path_project, "vocab.txt")
447
+
448
+ if not os.path.isfile(file_metadata):
449
+ return "The file was not found in " + file_metadata
450
+
451
+ with open(file_metadata, "r", encoding="utf-8") as f:
452
+ data = f.read()
453
+
454
+ audio_path_list = []
455
+ text_list = []
456
+ duration_list = []
457
+
458
+ count = data.split("\n")
459
+ lenght = 0
460
+ result = []
461
+ error_files = []
462
+ for line in progress.tqdm(data.split("\n"), total=count):
463
+ sp_line = line.split("|")
464
+ if len(sp_line) != 2:
465
+ continue
466
+ name_audio, text = sp_line[:2]
467
+
468
+ file_audio = os.path.join(path_project_wavs, name_audio + ".wav")
469
+
470
+ if not os.path.isfile(file_audio):
471
+ error_files.append(file_audio)
472
+ continue
473
+
474
+ duraction = get_audio_duration(file_audio)
475
+ if duraction < 2 and duraction > 15:
476
+ continue
477
+ if len(text) < 4:
478
+ continue
479
+
480
+ text = clear_text(text)
481
+ text = convert_char_to_pinyin([text], polyphone=True)[0]
482
+
483
+ audio_path_list.append(file_audio)
484
+ duration_list.append(duraction)
485
+ text_list.append(text)
486
+
487
+ result.append({"audio_path": file_audio, "text": text, "duration": duraction})
488
+
489
+ lenght += duraction
490
+
491
+ if duration_list == []:
492
+ error_files_text = "\n".join(error_files)
493
+ return f"Error: No audio files found in the specified path : \n{error_files_text}"
494
+
495
+ min_second = round(min(duration_list), 2)
496
+ max_second = round(max(duration_list), 2)
497
+
498
+ with ArrowWriter(path=file_raw, writer_batch_size=1) as writer:
499
+ for line in progress.tqdm(result, total=len(result), desc="prepare data"):
500
+ writer.write(line)
501
+
502
+ with open(file_duration, "w", encoding="utf-8") as f:
503
+ json.dump({"duration": duration_list}, f, ensure_ascii=False)
504
+
505
+ file_vocab_finetune = "data/Emilia_ZH_EN_pinyin/vocab.txt"
506
+ if not os.path.isfile(file_vocab_finetune):
507
+ return "Error: Vocabulary file 'Emilia_ZH_EN_pinyin' not found!"
508
+ shutil.copy2(file_vocab_finetune, file_vocab)
509
+
510
+ if error_files != []:
511
+ error_text = "error files\n" + "\n".join(error_files)
512
+ else:
513
+ error_text = ""
514
+
515
+ return f"prepare complete \nsamples : {len(text_list)}\ntime data : {format_seconds_to_hms(lenght)}\nmin sec : {min_second}\nmax sec : {max_second}\nfile_arrow : {file_raw}\n{error_text}"
516
+
517
+
518
+ def check_user(value):
519
+ return gr.update(visible=not value), gr.update(visible=value)
520
+
521
+
522
+ def calculate_train(
523
+ name_project,
524
+ batch_size_type,
525
+ max_samples,
526
+ learning_rate,
527
+ num_warmup_updates,
528
+ save_per_updates,
529
+ last_per_steps,
530
+ finetune,
531
+ ):
532
+ name_project += "_pinyin"
533
+ path_project = os.path.join(path_data, name_project)
534
+ file_duraction = os.path.join(path_project, "duration.json")
535
+
536
+ if not os.path.isfile(file_duraction):
537
+ return (
538
+ 1000,
539
+ max_samples,
540
+ num_warmup_updates,
541
+ save_per_updates,
542
+ last_per_steps,
543
+ "project not found !",
544
+ learning_rate,
545
+ )
546
+
547
+ with open(file_duraction, "r") as file:
548
+ data = json.load(file)
549
+
550
+ duration_list = data["duration"]
551
+
552
+ samples = len(duration_list)
553
+
554
+ if torch.cuda.is_available():
555
+ gpu_properties = torch.cuda.get_device_properties(0)
556
+ total_memory = gpu_properties.total_memory / (1024**3)
557
+ elif torch.backends.mps.is_available():
558
+ total_memory = psutil.virtual_memory().available / (1024**3)
559
+
560
+ if batch_size_type == "frame":
561
+ batch = int(total_memory * 0.5)
562
+ batch = (lambda num: num + 1 if num % 2 != 0 else num)(batch)
563
+ batch_size_per_gpu = int(38400 / batch)
564
+ else:
565
+ batch_size_per_gpu = int(total_memory / 8)
566
+ batch_size_per_gpu = (lambda num: num + 1 if num % 2 != 0 else num)(batch_size_per_gpu)
567
+ batch = batch_size_per_gpu
568
+
569
+ if batch_size_per_gpu <= 0:
570
+ batch_size_per_gpu = 1
571
+
572
+ if samples < 64:
573
+ max_samples = int(samples * 0.25)
574
+ else:
575
+ max_samples = 64
576
+
577
+ num_warmup_updates = int(samples * 0.05)
578
+ save_per_updates = int(samples * 0.10)
579
+ last_per_steps = int(save_per_updates * 5)
580
+
581
+ max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples)
582
+ num_warmup_updates = (lambda num: num + 1 if num % 2 != 0 else num)(num_warmup_updates)
583
+ save_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(save_per_updates)
584
+ last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps)
585
+
586
+ if finetune:
587
+ learning_rate = 1e-5
588
+ else:
589
+ learning_rate = 7.5e-5
590
+
591
+ return batch_size_per_gpu, max_samples, num_warmup_updates, save_per_updates, last_per_steps, samples, learning_rate
592
+
593
+
594
+ def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str) -> None:
595
+ try:
596
+ checkpoint = torch.load(checkpoint_path)
597
+ print("Original Checkpoint Keys:", checkpoint.keys())
598
+
599
+ ema_model_state_dict = checkpoint.get("ema_model_state_dict", None)
600
+
601
+ if ema_model_state_dict is not None:
602
+ new_checkpoint = {"ema_model_state_dict": ema_model_state_dict}
603
+ torch.save(new_checkpoint, new_checkpoint_path)
604
+ return f"New checkpoint saved at: {new_checkpoint_path}"
605
+ else:
606
+ return "No 'ema_model_state_dict' found in the checkpoint."
607
+
608
+ except Exception as e:
609
+ return f"An error occurred: {e}"
610
+
611
+
612
+ def vocab_check(project_name):
613
+ name_project = project_name + "_pinyin"
614
+ path_project = os.path.join(path_data, name_project)
615
+
616
+ file_metadata = os.path.join(path_project, "metadata.csv")
617
+
618
+ file_vocab = "data/Emilia_ZH_EN_pinyin/vocab.txt"
619
+ if not os.path.isfile(file_vocab):
620
+ return f"the file {file_vocab} not found !"
621
+
622
+ with open(file_vocab, "r", encoding="utf-8") as f:
623
+ data = f.read()
624
+
625
+ vocab = data.split("\n")
626
+
627
+ if not os.path.isfile(file_metadata):
628
+ return f"the file {file_metadata} not found !"
629
+
630
+ with open(file_metadata, "r", encoding="utf-8") as f:
631
+ data = f.read()
632
+
633
+ miss_symbols = []
634
+ miss_symbols_keep = {}
635
+ for item in data.split("\n"):
636
+ sp = item.split("|")
637
+ if len(sp) != 2:
638
+ continue
639
+
640
+ text = sp[1].lower().strip()
641
+
642
+ for t in text:
643
+ if t not in vocab and t not in miss_symbols_keep:
644
+ miss_symbols.append(t)
645
+ miss_symbols_keep[t] = t
646
+ if miss_symbols == []:
647
+ info = "You can train using your language !"
648
+ else:
649
+ info = f"The following symbols are missing in your language : {len(miss_symbols)}\n\n" + "\n".join(miss_symbols)
650
+
651
+ return info
652
+
653
+
654
+ def get_random_sample_prepare(project_name):
655
+ name_project = project_name + "_pinyin"
656
+ path_project = os.path.join(path_data, name_project)
657
+ file_arrow = os.path.join(path_project, "raw.arrow")
658
+ if not os.path.isfile(file_arrow):
659
+ return "", None
660
+ dataset = Dataset_.from_file(file_arrow)
661
+ random_sample = dataset.shuffle(seed=random.randint(0, 1000)).select([0])
662
+ text = "[" + " , ".join(["' " + t + " '" for t in random_sample["text"][0]]) + "]"
663
+ audio_path = random_sample["audio_path"][0]
664
+ return text, audio_path
665
+
666
+
667
+ def get_random_sample_transcribe(project_name):
668
+ name_project = project_name + "_pinyin"
669
+ path_project = os.path.join(path_data, name_project)
670
+ file_metadata = os.path.join(path_project, "metadata.csv")
671
+ if not os.path.isfile(file_metadata):
672
+ return "", None
673
+
674
+ data = ""
675
+ with open(file_metadata, "r", encoding="utf-8") as f:
676
+ data = f.read()
677
+
678
+ list_data = []
679
+ for item in data.split("\n"):
680
+ sp = item.split("|")
681
+ if len(sp) != 2:
682
+ continue
683
+ list_data.append([os.path.join(path_project, "wavs", sp[0] + ".wav"), sp[1]])
684
+
685
+ if list_data == []:
686
+ return "", None
687
+
688
+ random_item = random.choice(list_data)
689
+
690
+ return random_item[1], random_item[0]
691
+
692
+
693
+ def get_random_sample_infer(project_name):
694
+ text, audio = get_random_sample_transcribe(project_name)
695
+ return (
696
+ text,
697
+ text,
698
+ audio,
699
+ )
700
+
701
+
702
+ def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step):
703
+ global last_checkpoint, last_device, tts_api
704
+
705
+ if not os.path.isfile(file_checkpoint):
706
+ return None
707
+
708
+ if training_process is not None:
709
+ device_test = "cpu"
710
+ else:
711
+ device_test = None
712
+
713
+ if last_checkpoint != file_checkpoint or last_device != device_test:
714
+ if last_checkpoint != file_checkpoint:
715
+ last_checkpoint = file_checkpoint
716
+ if last_device != device_test:
717
+ last_device = device_test
718
+
719
+ tts_api = F5TTS(model_type=exp_name, ckpt_file=file_checkpoint, device=device_test)
720
+
721
+ print("update", device_test, file_checkpoint)
722
+
723
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
724
+ tts_api.infer(gen_text=gen_text, ref_text=ref_text, ref_file=ref_audio, nfe_step=nfe_step, file_wave=f.name)
725
+ return f.name
726
+
727
+
728
+ with gr.Blocks() as app:
729
+ with gr.Row():
730
+ project_name = gr.Textbox(label="project name", value="my_speak")
731
+ bt_create = gr.Button("create new project")
732
+
733
+ bt_create.click(fn=create_data_project, inputs=[project_name])
734
+
735
+ with gr.Tabs():
736
+ with gr.TabItem("transcribe Data"):
737
+ ch_manual = gr.Checkbox(label="user", value=False)
738
+
739
+ mark_info_transcribe = gr.Markdown(
740
+ """```plaintext
741
+ Place your 'wavs' folder and 'metadata.csv' file in the {your_project_name}' directory.
742
+
743
+ my_speak/
744
+ β”‚
745
+ └── dataset/
746
+ β”œβ”€β”€ audio1.wav
747
+ └── audio2.wav
748
+ ...
749
+ ```""",
750
+ visible=False,
751
+ )
752
+
753
+ audio_speaker = gr.File(label="voice", type="filepath", file_count="multiple")
754
+ txt_lang = gr.Text(label="Language", value="english")
755
+ bt_transcribe = bt_create = gr.Button("transcribe")
756
+ txt_info_transcribe = gr.Text(label="info", value="")
757
+ bt_transcribe.click(
758
+ fn=transcribe_all,
759
+ inputs=[project_name, audio_speaker, txt_lang, ch_manual],
760
+ outputs=[txt_info_transcribe],
761
+ )
762
+ ch_manual.change(fn=check_user, inputs=[ch_manual], outputs=[audio_speaker, mark_info_transcribe])
763
+
764
+ random_sample_transcribe = gr.Button("random sample")
765
+
766
+ with gr.Row():
767
+ random_text_transcribe = gr.Text(label="Text")
768
+ random_audio_transcribe = gr.Audio(label="Audio", type="filepath")
769
+
770
+ random_sample_transcribe.click(
771
+ fn=get_random_sample_transcribe,
772
+ inputs=[project_name],
773
+ outputs=[random_text_transcribe, random_audio_transcribe],
774
+ )
775
+
776
+ with gr.TabItem("prepare Data"):
777
+ gr.Markdown(
778
+ """```plaintext
779
+ place all your wavs folder and your metadata.csv file in {your name project}
780
+ my_speak/
781
+ β”‚
782
+ β”œβ”€β”€ wavs/
783
+ β”‚ β”œβ”€β”€ audio1.wav
784
+ β”‚ └── audio2.wav
785
+ | ...
786
+ β”‚
787
+ └── metadata.csv
788
+
789
+ file format metadata.csv
790
+
791
+ audio1|text1
792
+ audio2|text1
793
+ ...
794
+
795
+ ```"""
796
+ )
797
+
798
+ bt_prepare = bt_create = gr.Button("prepare")
799
+ txt_info_prepare = gr.Text(label="info", value="")
800
+ bt_prepare.click(fn=create_metadata, inputs=[project_name], outputs=[txt_info_prepare])
801
+
802
+ random_sample_prepare = gr.Button("random sample")
803
+
804
+ with gr.Row():
805
+ random_text_prepare = gr.Text(label="Pinyin")
806
+ random_audio_prepare = gr.Audio(label="Audio", type="filepath")
807
+
808
+ random_sample_prepare.click(
809
+ fn=get_random_sample_prepare, inputs=[project_name], outputs=[random_text_prepare, random_audio_prepare]
810
+ )
811
+
812
+ with gr.TabItem("train Data"):
813
+ with gr.Row():
814
+ bt_calculate = bt_create = gr.Button("Auto Settings")
815
+ ch_finetune = bt_create = gr.Checkbox(label="finetune", value=True)
816
+ lb_samples = gr.Label(label="samples")
817
+ batch_size_type = gr.Radio(label="Batch Size Type", choices=["frame", "sample"], value="frame")
818
+
819
+ with gr.Row():
820
+ exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
821
+ learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5)
822
+
823
+ with gr.Row():
824
+ batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=1000)
825
+ max_samples = gr.Number(label="Max Samples", value=64)
826
+
827
+ with gr.Row():
828
+ grad_accumulation_steps = gr.Number(label="Gradient Accumulation Steps", value=1)
829
+ max_grad_norm = gr.Number(label="Max Gradient Norm", value=1.0)
830
+
831
+ with gr.Row():
832
+ epochs = gr.Number(label="Epochs", value=10)
833
+ num_warmup_updates = gr.Number(label="Warmup Updates", value=5)
834
+
835
+ with gr.Row():
836
+ save_per_updates = gr.Number(label="Save per Updates", value=10)
837
+ last_per_steps = gr.Number(label="Last per Steps", value=50)
838
+
839
+ with gr.Row():
840
+ start_button = gr.Button("Start Training")
841
+ stop_button = gr.Button("Stop Training", interactive=False)
842
+
843
+ txt_info_train = gr.Text(label="info", value="")
844
+ start_button.click(
845
+ fn=start_training,
846
+ inputs=[
847
+ project_name,
848
+ exp_name,
849
+ learning_rate,
850
+ batch_size_per_gpu,
851
+ batch_size_type,
852
+ max_samples,
853
+ grad_accumulation_steps,
854
+ max_grad_norm,
855
+ epochs,
856
+ num_warmup_updates,
857
+ save_per_updates,
858
+ last_per_steps,
859
+ ch_finetune,
860
+ ],
861
+ outputs=[txt_info_train, start_button, stop_button],
862
+ )
863
+ stop_button.click(fn=stop_training, outputs=[txt_info_train, start_button, stop_button])
864
+ bt_calculate.click(
865
+ fn=calculate_train,
866
+ inputs=[
867
+ project_name,
868
+ batch_size_type,
869
+ max_samples,
870
+ learning_rate,
871
+ num_warmup_updates,
872
+ save_per_updates,
873
+ last_per_steps,
874
+ ch_finetune,
875
+ ],
876
+ outputs=[
877
+ batch_size_per_gpu,
878
+ max_samples,
879
+ num_warmup_updates,
880
+ save_per_updates,
881
+ last_per_steps,
882
+ lb_samples,
883
+ learning_rate,
884
+ ],
885
+ )
886
+
887
+ with gr.TabItem("reduse checkpoint"):
888
+ txt_path_checkpoint = gr.Text(label="path checkpoint :")
889
+ txt_path_checkpoint_small = gr.Text(label="path output :")
890
+ txt_info_reduse = gr.Text(label="info", value="")
891
+ reduse_button = gr.Button("reduse")
892
+ reduse_button.click(
893
+ fn=extract_and_save_ema_model,
894
+ inputs=[txt_path_checkpoint, txt_path_checkpoint_small],
895
+ outputs=[txt_info_reduse],
896
+ )
897
+
898
+ with gr.TabItem("vocab check experiment"):
899
+ check_button = gr.Button("check vocab")
900
+ txt_info_check = gr.Text(label="info", value="")
901
+ check_button.click(fn=vocab_check, inputs=[project_name], outputs=[txt_info_check])
902
+
903
+ with gr.TabItem("test model"):
904
+ exp_name = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
905
+ nfe_step = gr.Number(label="n_step", value=32)
906
+ file_checkpoint_pt = gr.Textbox(label="Checkpoint", value="")
907
+
908
+ random_sample_infer = gr.Button("random sample")
909
+
910
+ ref_text = gr.Textbox(label="ref text")
911
+ ref_audio = gr.Audio(label="audio ref", type="filepath")
912
+ gen_text = gr.Textbox(label="gen text")
913
+ random_sample_infer.click(
914
+ fn=get_random_sample_infer, inputs=[project_name], outputs=[ref_text, gen_text, ref_audio]
915
+ )
916
+ check_button_infer = gr.Button("infer")
917
+ gen_audio = gr.Audio(label="audio gen", type="filepath")
918
+
919
+ check_button_infer.click(
920
+ fn=infer,
921
+ inputs=[file_checkpoint_pt, exp_name, ref_text, ref_audio, gen_text, nfe_step],
922
+ outputs=[gen_audio],
923
+ )
924
+
925
+
926
+ @click.command()
927
+ @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
928
+ @click.option("--host", "-H", default=None, help="Host to run the app on")
929
+ @click.option(
930
+ "--share",
931
+ "-s",
932
+ default=False,
933
+ is_flag=True,
934
+ help="Share the app via Gradio share link",
935
+ )
936
+ @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
937
+ def main(port, host, share, api):
938
+ global app
939
+ print("Starting app...")
940
+ app.queue(api_open=api).launch(server_name=host, server_port=port, share=share, show_api=api)
941
+
942
+
943
+ if __name__ == "__main__":
944
+ main()