Spaces:
Running
on
Zero
Running
on
Zero
mrfakename
commited on
Commit
•
cf68f41
1
Parent(s):
57b3db8
Sync from GitHub repo
Browse filesThis Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there
src/f5_tts/model/trainer.py
CHANGED
@@ -47,6 +47,8 @@ class Trainer:
|
|
47 |
ema_kwargs: dict = dict(),
|
48 |
bnb_optimizer: bool = False,
|
49 |
mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
|
|
|
|
|
50 |
):
|
51 |
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
52 |
|
@@ -108,7 +110,11 @@ class Trainer:
|
|
108 |
self.max_samples = max_samples
|
109 |
self.grad_accumulation_steps = grad_accumulation_steps
|
110 |
self.max_grad_norm = max_grad_norm
|
|
|
|
|
111 |
self.vocoder_name = mel_spec_type
|
|
|
|
|
112 |
|
113 |
self.noise_scheduler = noise_scheduler
|
114 |
|
@@ -199,7 +205,9 @@ class Trainer:
|
|
199 |
if self.log_samples:
|
200 |
from f5_tts.infer.utils_infer import cfg_strength, load_vocoder, nfe_step, sway_sampling_coef
|
201 |
|
202 |
-
vocoder = load_vocoder(
|
|
|
|
|
203 |
target_sample_rate = self.accelerator.unwrap_model(self.model).mel_spec.target_sample_rate
|
204 |
log_samples_path = f"{self.checkpoint_path}/samples"
|
205 |
os.makedirs(log_samples_path, exist_ok=True)
|
|
|
47 |
ema_kwargs: dict = dict(),
|
48 |
bnb_optimizer: bool = False,
|
49 |
mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
|
50 |
+
is_local_vocoder: bool = False, # use local path vocoder
|
51 |
+
local_vocoder_path: str = "", # local vocoder path
|
52 |
):
|
53 |
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
54 |
|
|
|
110 |
self.max_samples = max_samples
|
111 |
self.grad_accumulation_steps = grad_accumulation_steps
|
112 |
self.max_grad_norm = max_grad_norm
|
113 |
+
|
114 |
+
# mel vocoder config
|
115 |
self.vocoder_name = mel_spec_type
|
116 |
+
self.is_local_vocoder = is_local_vocoder
|
117 |
+
self.local_vocoder_path = local_vocoder_path
|
118 |
|
119 |
self.noise_scheduler = noise_scheduler
|
120 |
|
|
|
205 |
if self.log_samples:
|
206 |
from f5_tts.infer.utils_infer import cfg_strength, load_vocoder, nfe_step, sway_sampling_coef
|
207 |
|
208 |
+
vocoder = load_vocoder(
|
209 |
+
vocoder_name=self.vocoder_name, is_local=self.is_local_vocoder, local_path=self.local_vocoder_path
|
210 |
+
)
|
211 |
target_sample_rate = self.accelerator.unwrap_model(self.model).mel_spec.target_sample_rate
|
212 |
log_samples_path = f"{self.checkpoint_path}/samples"
|
213 |
os.makedirs(log_samples_path, exist_ok=True)
|