mrfakename commited on
Commit
b6584c2
1 Parent(s): 0cc615c

Sync from GitHub repo

Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there

app.py CHANGED
@@ -37,7 +37,7 @@ from f5_tts.infer.utils_infer import (
37
  save_spectrogram,
38
  )
39
 
40
- vocos = load_vocoder()
41
 
42
 
43
  # load models
@@ -94,6 +94,7 @@ def infer(
94
  ref_text,
95
  gen_text,
96
  ema_model,
 
97
  cross_fade_duration=cross_fade_duration,
98
  speed=speed,
99
  show_info=show_info,
 
37
  save_spectrogram,
38
  )
39
 
40
+ vocoder = load_vocoder()
41
 
42
 
43
  # load models
 
94
  ref_text,
95
  gen_text,
96
  ema_model,
97
+ vocoder,
98
  cross_fade_duration=cross_fade_duration,
99
  speed=speed,
100
  show_info=show_info,
src/f5_tts/api.py CHANGED
@@ -47,7 +47,7 @@ class F5TTS:
47
  self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)
48
 
49
  def load_vocoder_model(self, local_path):
50
- self.vocos = load_vocoder(local_path is not None, local_path, self.device)
51
 
52
  def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
53
  if model_type == "F5-TTS":
@@ -102,6 +102,7 @@ class F5TTS:
102
  ref_text,
103
  gen_text,
104
  self.ema_model,
 
105
  show_info=show_info,
106
  progress=progress,
107
  target_rms=target_rms,
 
47
  self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)
48
 
49
  def load_vocoder_model(self, local_path):
50
+ self.vocoder = load_vocoder(local_path is not None, local_path, self.device)
51
 
52
  def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
53
  if model_type == "F5-TTS":
 
102
  ref_text,
103
  gen_text,
104
  self.ema_model,
105
+ self.vocoder,
106
  show_info=show_info,
107
  progress=progress,
108
  target_rms=target_rms,
src/f5_tts/infer/infer_cli.py CHANGED
@@ -113,7 +113,7 @@ wave_path = Path(output_dir) / "infer_cli_out.wav"
113
  # spectrogram_path = Path(output_dir) / "infer_cli_out.png"
114
  vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
115
 
116
- vocos = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_local_path)
117
 
118
 
119
  # load models
@@ -175,7 +175,9 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence, speed
175
  ref_audio = voices[voice]["ref_audio"]
176
  ref_text = voices[voice]["ref_text"]
177
  print(f"Voice: {voice}")
178
- audio, final_sample_rate, spectragram = infer_process(ref_audio, ref_text, gen_text, model_obj, speed=speed)
 
 
179
  generated_audio_segments.append(audio)
180
 
181
  if generated_audio_segments:
 
113
  # spectrogram_path = Path(output_dir) / "infer_cli_out.png"
114
  vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
115
 
116
+ vocoder = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_local_path)
117
 
118
 
119
  # load models
 
175
  ref_audio = voices[voice]["ref_audio"]
176
  ref_text = voices[voice]["ref_text"]
177
  print(f"Voice: {voice}")
178
+ audio, final_sample_rate, spectragram = infer_process(
179
+ ref_audio, ref_text, gen_text, model_obj, vocoder, speed=speed
180
+ )
181
  generated_audio_segments.append(audio)
182
 
183
  if generated_audio_segments:
src/f5_tts/infer/utils_infer.py CHANGED
@@ -29,9 +29,6 @@ _ref_audio_cache = {}
29
 
30
  device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
31
 
32
- vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
33
-
34
-
35
  # -----------------------------------------
36
 
37
  target_sample_rate = 24000
@@ -263,6 +260,7 @@ def infer_process(
263
  ref_text,
264
  gen_text,
265
  model_obj,
 
266
  show_info=print,
267
  progress=tqdm,
268
  target_rms=target_rms,
@@ -287,6 +285,7 @@ def infer_process(
287
  ref_text,
288
  gen_text_batches,
289
  model_obj,
 
290
  progress=progress,
291
  target_rms=target_rms,
292
  cross_fade_duration=cross_fade_duration,
@@ -307,6 +306,7 @@ def infer_batch_process(
307
  ref_text,
308
  gen_text_batches,
309
  model_obj,
 
310
  progress=tqdm,
311
  target_rms=0.1,
312
  cross_fade_duration=0.15,
@@ -362,7 +362,7 @@ def infer_batch_process(
362
  generated = generated.to(torch.float32)
363
  generated = generated[:, ref_audio_len:, :]
364
  generated_mel_spec = generated.permute(0, 2, 1)
365
- generated_wave = vocos.decode(generated_mel_spec.cpu())
366
  if rms < target_rms:
367
  generated_wave = generated_wave * rms / target_rms
368
 
 
29
 
30
  device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
31
 
 
 
 
32
  # -----------------------------------------
33
 
34
  target_sample_rate = 24000
 
260
  ref_text,
261
  gen_text,
262
  model_obj,
263
+ vocoder,
264
  show_info=print,
265
  progress=tqdm,
266
  target_rms=target_rms,
 
285
  ref_text,
286
  gen_text_batches,
287
  model_obj,
288
+ vocoder,
289
  progress=progress,
290
  target_rms=target_rms,
291
  cross_fade_duration=cross_fade_duration,
 
306
  ref_text,
307
  gen_text_batches,
308
  model_obj,
309
+ vocoder,
310
  progress=tqdm,
311
  target_rms=0.1,
312
  cross_fade_duration=0.15,
 
362
  generated = generated.to(torch.float32)
363
  generated = generated[:, ref_audio_len:, :]
364
  generated_mel_spec = generated.permute(0, 2, 1)
365
+ generated_wave = vocoder.decode(generated_mel_spec.cpu())
366
  if rms < target_rms:
367
  generated_wave = generated_wave * rms / target_rms
368
 
src/f5_tts/model/trainer.py CHANGED
@@ -6,6 +6,7 @@ from tqdm import tqdm
6
  import wandb
7
 
8
  import torch
 
9
  from torch.optim import AdamW
10
  from torch.utils.data import DataLoader, Dataset, SequentialSampler
11
  from torch.optim.lr_scheduler import LinearLR, SequentialLR
@@ -39,9 +40,11 @@ class Trainer:
39
  max_grad_norm=1.0,
40
  noise_scheduler: str | None = None,
41
  duration_predictor: torch.nn.Module | None = None,
 
42
  wandb_project="test_e2-tts",
43
  wandb_run_name="test_run",
44
  wandb_resume_id: str = None,
 
45
  last_per_steps=None,
46
  accelerate_kwargs: dict = dict(),
47
  ema_kwargs: dict = dict(),
@@ -49,21 +52,25 @@ class Trainer:
49
  ):
50
  ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
51
 
52
- logger = "wandb" if wandb.api.api_key else None
 
53
  print(f"Using logger: {logger}")
 
54
 
55
  self.accelerator = Accelerator(
56
- log_with=logger,
57
  kwargs_handlers=[ddp_kwargs],
58
  gradient_accumulation_steps=grad_accumulation_steps,
59
  **accelerate_kwargs,
60
  )
61
 
62
- if logger == "wandb":
 
63
  if exists(wandb_resume_id):
64
  init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}}
65
  else:
66
  init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
 
67
  self.accelerator.init_trackers(
68
  project_name=wandb_project,
69
  init_kwargs=init_kwargs,
@@ -81,11 +88,15 @@ class Trainer:
81
  },
82
  )
83
 
 
 
 
 
 
84
  self.model = model
85
 
86
  if self.is_main:
87
  self.ema_model = EMA(model, include_online_model=False, **ema_kwargs)
88
-
89
  self.ema_model.to(self.accelerator.device)
90
 
91
  self.epochs = epochs
@@ -176,6 +187,14 @@ class Trainer:
176
  return step
177
 
178
  def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
 
 
 
 
 
 
 
 
179
  if exists(resumable_with_seed):
180
  generator = torch.Generator()
181
  generator.manual_seed(resumable_with_seed)
@@ -286,12 +305,31 @@ class Trainer:
286
 
287
  if self.accelerator.is_local_main_process:
288
  self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
 
 
 
289
 
290
  progress_bar.set_postfix(step=str(global_step), loss=loss.item())
291
 
292
  if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
293
  self.save_checkpoint(global_step)
294
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
  if global_step % self.last_per_steps == 0:
296
  self.save_checkpoint(global_step, last=True)
297
 
 
6
  import wandb
7
 
8
  import torch
9
+ import torchaudio
10
  from torch.optim import AdamW
11
  from torch.utils.data import DataLoader, Dataset, SequentialSampler
12
  from torch.optim.lr_scheduler import LinearLR, SequentialLR
 
40
  max_grad_norm=1.0,
41
  noise_scheduler: str | None = None,
42
  duration_predictor: torch.nn.Module | None = None,
43
+ logger: str | None = "wandb", # "wandb" | "tensorboard" | None
44
  wandb_project="test_e2-tts",
45
  wandb_run_name="test_run",
46
  wandb_resume_id: str = None,
47
+ log_samples: bool = False,
48
  last_per_steps=None,
49
  accelerate_kwargs: dict = dict(),
50
  ema_kwargs: dict = dict(),
 
52
  ):
53
  ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
54
 
55
+ if logger == "wandb" and not wandb.api.api_key:
56
+ logger = None
57
  print(f"Using logger: {logger}")
58
+ self.log_samples = log_samples
59
 
60
  self.accelerator = Accelerator(
61
+ log_with=logger if logger == "wandb" else None,
62
  kwargs_handlers=[ddp_kwargs],
63
  gradient_accumulation_steps=grad_accumulation_steps,
64
  **accelerate_kwargs,
65
  )
66
 
67
+ self.logger = logger
68
+ if self.logger == "wandb":
69
  if exists(wandb_resume_id):
70
  init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}}
71
  else:
72
  init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
73
+
74
  self.accelerator.init_trackers(
75
  project_name=wandb_project,
76
  init_kwargs=init_kwargs,
 
88
  },
89
  )
90
 
91
+ elif self.logger == "tensorboard":
92
+ from torch.utils.tensorboard import SummaryWriter
93
+
94
+ self.writer = SummaryWriter(log_dir=f"runs/{wandb_run_name}")
95
+
96
  self.model = model
97
 
98
  if self.is_main:
99
  self.ema_model = EMA(model, include_online_model=False, **ema_kwargs)
 
100
  self.ema_model.to(self.accelerator.device)
101
 
102
  self.epochs = epochs
 
187
  return step
188
 
189
  def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
190
+ if self.log_samples:
191
+ from f5_tts.infer.utils_infer import load_vocoder, nfe_step, cfg_strength, sway_sampling_coef
192
+
193
+ vocoder = load_vocoder()
194
+ target_sample_rate = self.accelerator.unwrap_model(self.model).mel_spec.mel_stft.sample_rate
195
+ log_samples_path = f"{self.checkpoint_path}/samples"
196
+ os.makedirs(log_samples_path, exist_ok=True)
197
+
198
  if exists(resumable_with_seed):
199
  generator = torch.Generator()
200
  generator.manual_seed(resumable_with_seed)
 
305
 
306
  if self.accelerator.is_local_main_process:
307
  self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
308
+ if self.logger == "tensorboard":
309
+ self.writer.add_scalar("loss", loss.item(), global_step)
310
+ self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_step)
311
 
312
  progress_bar.set_postfix(step=str(global_step), loss=loss.item())
313
 
314
  if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
315
  self.save_checkpoint(global_step)
316
 
317
+ if self.log_samples and self.accelerator.is_local_main_process:
318
+ ref_audio, ref_audio_len = vocoder.decode(batch["mel"][0].unsqueeze(0).cpu()), mel_lengths[0]
319
+ torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
320
+ with torch.inference_mode():
321
+ generated, _ = self.accelerator.unwrap_model(self.model).sample(
322
+ cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
323
+ text=[text_inputs[0] + [" "] + text_inputs[0]],
324
+ duration=ref_audio_len * 2,
325
+ steps=nfe_step,
326
+ cfg_strength=cfg_strength,
327
+ sway_sampling_coef=sway_sampling_coef,
328
+ )
329
+ generated = generated.to(torch.float32)
330
+ gen_audio = vocoder.decode(generated[:, ref_audio_len:, :].permute(0, 2, 1).cpu())
331
+ torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio, target_sample_rate)
332
+
333
  if global_step % self.last_per_steps == 0:
334
  self.save_checkpoint(global_step, last=True)
335
 
src/f5_tts/train/finetune_cli.py CHANGED
@@ -56,6 +56,14 @@ def parse_args():
56
  help="Path to custom tokenizer vocab file (only used if tokenizer = 'custom')",
57
  )
58
 
 
 
 
 
 
 
 
 
59
  return parser.parse_args()
60
 
61
 
@@ -64,6 +72,7 @@ def parse_args():
64
 
65
  def main():
66
  args = parse_args()
 
67
  checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}"))
68
 
69
  # Model parameters based on experiment name
@@ -132,9 +141,11 @@ def main():
132
  max_samples=args.max_samples,
133
  grad_accumulation_steps=args.grad_accumulation_steps,
134
  max_grad_norm=args.max_grad_norm,
 
135
  wandb_project=args.dataset_name,
136
  wandb_run_name=args.exp_name,
137
  wandb_resume_id=wandb_resume_id,
 
138
  last_per_steps=args.last_per_steps,
139
  )
140
 
 
56
  help="Path to custom tokenizer vocab file (only used if tokenizer = 'custom')",
57
  )
58
 
59
+ parser.add_argument(
60
+ "--log_samples",
61
+ type=bool,
62
+ default=False,
63
+ help="Log inferenced samples per ckpt save steps",
64
+ )
65
+ parser.add_argument("--logger", type=str, default=None, choices=["wandb", "tensorboard"], help="logger")
66
+
67
  return parser.parse_args()
68
 
69
 
 
72
 
73
  def main():
74
  args = parse_args()
75
+
76
  checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}"))
77
 
78
  # Model parameters based on experiment name
 
141
  max_samples=args.max_samples,
142
  grad_accumulation_steps=args.grad_accumulation_steps,
143
  max_grad_norm=args.max_grad_norm,
144
+ logger=args.logger,
145
  wandb_project=args.dataset_name,
146
  wandb_run_name=args.exp_name,
147
  wandb_resume_id=wandb_resume_id,
148
+ log_samples=args.log_samples,
149
  last_per_steps=args.last_per_steps,
150
  )
151
 
src/f5_tts/train/finetune_gradio.py CHANGED
@@ -69,6 +69,7 @@ def save_settings(
69
  tokenizer_type,
70
  tokenizer_file,
71
  mixed_precision,
 
72
  ):
73
  path_project = os.path.join(path_project_ckpts, project_name)
74
  os.makedirs(path_project, exist_ok=True)
@@ -91,6 +92,7 @@ def save_settings(
91
  "tokenizer_type": tokenizer_type,
92
  "tokenizer_file": tokenizer_file,
93
  "mixed_precision": mixed_precision,
 
94
  }
95
  with open(file_setting, "w") as f:
96
  json.dump(settings, f, indent=4)
@@ -121,6 +123,7 @@ def load_settings(project_name):
121
  "tokenizer_type": "pinyin",
122
  "tokenizer_file": "",
123
  "mixed_precision": "none",
 
124
  }
125
  return (
126
  settings["exp_name"],
@@ -139,6 +142,7 @@ def load_settings(project_name):
139
  settings["tokenizer_type"],
140
  settings["tokenizer_file"],
141
  settings["mixed_precision"],
 
142
  )
143
 
144
  with open(file_setting, "r") as f:
@@ -160,6 +164,7 @@ def load_settings(project_name):
160
  settings["tokenizer_type"],
161
  settings["tokenizer_file"],
162
  settings["mixed_precision"],
 
163
  )
164
 
165
 
@@ -374,6 +379,7 @@ def start_training(
374
  tokenizer_file="",
375
  mixed_precision="fp16",
376
  stream=False,
 
377
  ):
378
  global training_process, tts_api, stop_signal
379
 
@@ -447,6 +453,8 @@ def start_training(
447
 
448
  cmd += f" --tokenizer {tokenizer_type} "
449
 
 
 
450
  print(cmd)
451
 
452
  save_settings(
@@ -467,6 +475,7 @@ def start_training(
467
  tokenizer_type,
468
  tokenizer_file,
469
  mixed_precision,
 
470
  )
471
 
472
  try:
@@ -1223,6 +1232,27 @@ def get_checkpoints_project(project_name, is_gradio=True):
1223
  return files_checkpoints, selelect_checkpoint
1224
 
1225
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1226
  def get_gpu_stats():
1227
  gpu_stats = ""
1228
 
@@ -1290,6 +1320,17 @@ def get_combined_stats():
1290
  return combined_stats
1291
 
1292
 
 
 
 
 
 
 
 
 
 
 
 
1293
  with gr.Blocks() as app:
1294
  gr.Markdown(
1295
  """
@@ -1470,6 +1511,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1470
 
1471
  with gr.Row():
1472
  mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "fpb16"], value="none")
 
1473
  start_button = gr.Button("Start Training")
1474
  stop_button = gr.Button("Stop Training", interactive=False)
1475
 
@@ -1491,6 +1533,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1491
  tokenizer_typev,
1492
  tokenizer_filev,
1493
  mixed_precisionv,
 
1494
  ) = load_settings(projects_selelect)
1495
  exp_name.value = exp_namev
1496
  learning_rate.value = learning_ratev
@@ -1508,9 +1551,43 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1508
  tokenizer_type.value = tokenizer_typev
1509
  tokenizer_file.value = tokenizer_filev
1510
  mixed_precision.value = mixed_precisionv
 
1511
 
1512
  ch_stream = gr.Checkbox(label="stream output experiment.", value=True)
1513
  txt_info_train = gr.Text(label="info", value="")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1514
  start_button.click(
1515
  fn=start_training,
1516
  inputs=[
@@ -1532,6 +1609,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1532
  tokenizer_file,
1533
  mixed_precision,
1534
  ch_stream,
 
1535
  ],
1536
  outputs=[txt_info_train, start_button, stop_button],
1537
  )
@@ -1583,6 +1661,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1583
  tokenizer_type,
1584
  tokenizer_file,
1585
  mixed_precision,
 
1586
  ]
1587
 
1588
  return output_components
 
69
  tokenizer_type,
70
  tokenizer_file,
71
  mixed_precision,
72
+ logger,
73
  ):
74
  path_project = os.path.join(path_project_ckpts, project_name)
75
  os.makedirs(path_project, exist_ok=True)
 
92
  "tokenizer_type": tokenizer_type,
93
  "tokenizer_file": tokenizer_file,
94
  "mixed_precision": mixed_precision,
95
+ "logger": logger,
96
  }
97
  with open(file_setting, "w") as f:
98
  json.dump(settings, f, indent=4)
 
123
  "tokenizer_type": "pinyin",
124
  "tokenizer_file": "",
125
  "mixed_precision": "none",
126
+ "logger": "wandb",
127
  }
128
  return (
129
  settings["exp_name"],
 
142
  settings["tokenizer_type"],
143
  settings["tokenizer_file"],
144
  settings["mixed_precision"],
145
+ settings["logger"],
146
  )
147
 
148
  with open(file_setting, "r") as f:
 
164
  settings["tokenizer_type"],
165
  settings["tokenizer_file"],
166
  settings["mixed_precision"],
167
+ settings["logger"],
168
  )
169
 
170
 
 
379
  tokenizer_file="",
380
  mixed_precision="fp16",
381
  stream=False,
382
+ logger="wandb",
383
  ):
384
  global training_process, tts_api, stop_signal
385
 
 
453
 
454
  cmd += f" --tokenizer {tokenizer_type} "
455
 
456
+ cmd += f" --log_samples True --logger {logger} "
457
+
458
  print(cmd)
459
 
460
  save_settings(
 
475
  tokenizer_type,
476
  tokenizer_file,
477
  mixed_precision,
478
+ logger,
479
  )
480
 
481
  try:
 
1232
  return files_checkpoints, selelect_checkpoint
1233
 
1234
 
1235
+ def get_audio_project(project_name, is_gradio=True):
1236
+ if project_name is None:
1237
+ return [], ""
1238
+ project_name = project_name.replace("_pinyin", "").replace("_char", "")
1239
+
1240
+ if os.path.isdir(path_project_ckpts):
1241
+ files_audios = glob(os.path.join(path_project_ckpts, project_name, "samples", "*.wav"))
1242
+ files_audios = sorted(files_audios, key=lambda x: int(os.path.basename(x).split("_")[1].split(".")[0]))
1243
+
1244
+ files_audios = [item.replace("_gen.wav", "") for item in files_audios if item.endswith("_gen.wav")]
1245
+ else:
1246
+ files_audios = []
1247
+
1248
+ selelect_checkpoint = None if not files_audios else files_audios[0]
1249
+
1250
+ if is_gradio:
1251
+ return gr.update(choices=files_audios, value=selelect_checkpoint)
1252
+
1253
+ return files_audios, selelect_checkpoint
1254
+
1255
+
1256
  def get_gpu_stats():
1257
  gpu_stats = ""
1258
 
 
1320
  return combined_stats
1321
 
1322
 
1323
+ def get_audio_select(file_sample):
1324
+ select_audio_ref = file_sample
1325
+ select_audio_gen = file_sample
1326
+
1327
+ if file_sample is not None:
1328
+ select_audio_ref += "_ref.wav"
1329
+ select_audio_gen += "_gen.wav"
1330
+
1331
+ return select_audio_ref, select_audio_gen
1332
+
1333
+
1334
  with gr.Blocks() as app:
1335
  gr.Markdown(
1336
  """
 
1511
 
1512
  with gr.Row():
1513
  mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "fpb16"], value="none")
1514
+ cd_logger = gr.Radio(label="logger", choices=["wandb", "tensorboard"], value="wandb")
1515
  start_button = gr.Button("Start Training")
1516
  stop_button = gr.Button("Stop Training", interactive=False)
1517
 
 
1533
  tokenizer_typev,
1534
  tokenizer_filev,
1535
  mixed_precisionv,
1536
+ cd_loggerv,
1537
  ) = load_settings(projects_selelect)
1538
  exp_name.value = exp_namev
1539
  learning_rate.value = learning_ratev
 
1551
  tokenizer_type.value = tokenizer_typev
1552
  tokenizer_file.value = tokenizer_filev
1553
  mixed_precision.value = mixed_precisionv
1554
+ cd_logger.value = cd_loggerv
1555
 
1556
  ch_stream = gr.Checkbox(label="stream output experiment.", value=True)
1557
  txt_info_train = gr.Text(label="info", value="")
1558
+
1559
+ list_audios, select_audio = get_audio_project(projects_selelect, False)
1560
+
1561
+ select_audio_ref = select_audio
1562
+ select_audio_gen = select_audio
1563
+
1564
+ if select_audio is not None:
1565
+ select_audio_ref += "_ref.wav"
1566
+ select_audio_gen += "_gen.wav"
1567
+
1568
+ with gr.Row():
1569
+ ch_list_audio = gr.Dropdown(
1570
+ choices=list_audios,
1571
+ value=select_audio,
1572
+ label="audios",
1573
+ allow_custom_value=True,
1574
+ scale=6,
1575
+ interactive=True,
1576
+ )
1577
+ bt_stream_audio = gr.Button("refresh", scale=1)
1578
+ bt_stream_audio.click(fn=get_audio_project, inputs=[cm_project], outputs=[ch_list_audio])
1579
+ cm_project.change(fn=get_audio_project, inputs=[cm_project], outputs=[ch_list_audio])
1580
+
1581
+ with gr.Row():
1582
+ audio_ref_stream = gr.Audio(label="original", type="filepath", value=select_audio_ref)
1583
+ audio_gen_stream = gr.Audio(label="generate", type="filepath", value=select_audio_gen)
1584
+
1585
+ ch_list_audio.change(
1586
+ fn=get_audio_select,
1587
+ inputs=[ch_list_audio],
1588
+ outputs=[audio_ref_stream, audio_gen_stream],
1589
+ )
1590
+
1591
  start_button.click(
1592
  fn=start_training,
1593
  inputs=[
 
1609
  tokenizer_file,
1610
  mixed_precision,
1611
  ch_stream,
1612
+ cd_logger,
1613
  ],
1614
  outputs=[txt_info_train, start_button, stop_button],
1615
  )
 
1661
  tokenizer_type,
1662
  tokenizer_file,
1663
  mixed_precision,
1664
+ cd_logger,
1665
  ]
1666
 
1667
  return output_components
src/f5_tts/train/train.py CHANGED
@@ -83,6 +83,7 @@ def main():
83
  wandb_run_name=exp_name,
84
  wandb_resume_id=wandb_resume_id,
85
  last_per_steps=last_per_steps,
 
86
  )
87
 
88
  train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
 
83
  wandb_run_name=exp_name,
84
  wandb_resume_id=wandb_resume_id,
85
  last_per_steps=last_per_steps,
86
+ log_samples=True,
87
  )
88
 
89
  train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)