Spaces:
Paused
Paused
Sync from GitHub repo
Browse filesThis Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there
src/f5_tts/model/trainer.py
CHANGED
|
@@ -47,6 +47,8 @@ class Trainer:
|
|
| 47 |
ema_kwargs: dict = dict(),
|
| 48 |
bnb_optimizer: bool = False,
|
| 49 |
mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
|
|
|
|
|
|
|
| 50 |
):
|
| 51 |
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
| 52 |
|
|
@@ -108,7 +110,11 @@ class Trainer:
|
|
| 108 |
self.max_samples = max_samples
|
| 109 |
self.grad_accumulation_steps = grad_accumulation_steps
|
| 110 |
self.max_grad_norm = max_grad_norm
|
|
|
|
|
|
|
| 111 |
self.vocoder_name = mel_spec_type
|
|
|
|
|
|
|
| 112 |
|
| 113 |
self.noise_scheduler = noise_scheduler
|
| 114 |
|
|
@@ -199,7 +205,9 @@ class Trainer:
|
|
| 199 |
if self.log_samples:
|
| 200 |
from f5_tts.infer.utils_infer import cfg_strength, load_vocoder, nfe_step, sway_sampling_coef
|
| 201 |
|
| 202 |
-
vocoder = load_vocoder(
|
|
|
|
|
|
|
| 203 |
target_sample_rate = self.accelerator.unwrap_model(self.model).mel_spec.target_sample_rate
|
| 204 |
log_samples_path = f"{self.checkpoint_path}/samples"
|
| 205 |
os.makedirs(log_samples_path, exist_ok=True)
|
|
|
|
| 47 |
ema_kwargs: dict = dict(),
|
| 48 |
bnb_optimizer: bool = False,
|
| 49 |
mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
|
| 50 |
+
is_local_vocoder: bool = False, # use local path vocoder
|
| 51 |
+
local_vocoder_path: str = "", # local vocoder path
|
| 52 |
):
|
| 53 |
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
| 54 |
|
|
|
|
| 110 |
self.max_samples = max_samples
|
| 111 |
self.grad_accumulation_steps = grad_accumulation_steps
|
| 112 |
self.max_grad_norm = max_grad_norm
|
| 113 |
+
|
| 114 |
+
# mel vocoder config
|
| 115 |
self.vocoder_name = mel_spec_type
|
| 116 |
+
self.is_local_vocoder = is_local_vocoder
|
| 117 |
+
self.local_vocoder_path = local_vocoder_path
|
| 118 |
|
| 119 |
self.noise_scheduler = noise_scheduler
|
| 120 |
|
|
|
|
| 205 |
if self.log_samples:
|
| 206 |
from f5_tts.infer.utils_infer import cfg_strength, load_vocoder, nfe_step, sway_sampling_coef
|
| 207 |
|
| 208 |
+
vocoder = load_vocoder(
|
| 209 |
+
vocoder_name=self.vocoder_name, is_local=self.is_local_vocoder, local_path=self.local_vocoder_path
|
| 210 |
+
)
|
| 211 |
target_sample_rate = self.accelerator.unwrap_model(self.model).mel_spec.target_sample_rate
|
| 212 |
log_samples_path = f"{self.checkpoint_path}/samples"
|
| 213 |
os.makedirs(log_samples_path, exist_ok=True)
|