Spaces:
Running
Running
random 60s crop at training time (matches Side-Step chunk-duration), remove pre-split chunking
Browse files- app.py +4 -12
- train_engine.py +47 -47
app.py
CHANGED
|
@@ -749,7 +749,7 @@ def gradio_main():
|
|
| 749 |
processed = result.get("processed", 0)
|
| 750 |
failed = result.get("failed", 0)
|
| 751 |
total = result.get("total", 0)
|
| 752 |
-
_log(f"[OK] Preprocessed: {total} files
|
| 753 |
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
|
| 754 |
|
| 755 |
if processed == 0:
|
|
@@ -757,18 +757,9 @@ def gradio_main():
|
|
| 757 |
yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
|
| 758 |
return
|
| 759 |
|
| 760 |
-
# Auto-scale epochs: chunking multiplies samples, reduce epochs
|
| 761 |
-
# to keep total gradient updates ~constant
|
| 762 |
-
effective_epochs = epochs
|
| 763 |
-
if processed > total and total > 0:
|
| 764 |
-
scale = total / processed
|
| 765 |
-
effective_epochs = max(10, int(epochs * scale))
|
| 766 |
-
_log(f"[INFO] Auto-scaled epochs: {epochs} -> {effective_epochs} "
|
| 767 |
-
f"(chunking: {total} files -> {processed} samples)")
|
| 768 |
-
|
| 769 |
_gc.collect()
|
| 770 |
|
| 771 |
-
# -- Phase 2: Training --
|
| 772 |
_log("[Step 2/2] Training LoRA...")
|
| 773 |
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
|
| 774 |
|
|
@@ -776,7 +767,7 @@ def gradio_main():
|
|
| 776 |
dataset_dir=preprocessed_dir,
|
| 777 |
output_dir=adapter_out,
|
| 778 |
checkpoint_dir=ACE_CHECKPOINT_DIR,
|
| 779 |
-
epochs=
|
| 780 |
lr=lr,
|
| 781 |
rank=rank,
|
| 782 |
alpha=rank * 2,
|
|
@@ -790,6 +781,7 @@ def gradio_main():
|
|
| 790 |
seed=42,
|
| 791 |
variant="turbo",
|
| 792 |
device="cpu",
|
|
|
|
| 793 |
log_every=5,
|
| 794 |
):
|
| 795 |
elapsed = time.time() - train_start
|
|
|
|
| 749 |
processed = result.get("processed", 0)
|
| 750 |
failed = result.get("failed", 0)
|
| 751 |
total = result.get("total", 0)
|
| 752 |
+
_log(f"[OK] Preprocessed: {processed}/{total} files (failed: {failed})")
|
| 753 |
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
|
| 754 |
|
| 755 |
if processed == 0:
|
|
|
|
| 757 |
yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
|
| 758 |
return
|
| 759 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 760 |
_gc.collect()
|
| 761 |
|
| 762 |
+
# -- Phase 2: Training (random 60s crops for speed + augmentation) --
|
| 763 |
_log("[Step 2/2] Training LoRA...")
|
| 764 |
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
|
| 765 |
|
|
|
|
| 767 |
dataset_dir=preprocessed_dir,
|
| 768 |
output_dir=adapter_out,
|
| 769 |
checkpoint_dir=ACE_CHECKPOINT_DIR,
|
| 770 |
+
epochs=epochs,
|
| 771 |
lr=lr,
|
| 772 |
rank=rank,
|
| 773 |
alpha=rank * 2,
|
|
|
|
| 781 |
seed=42,
|
| 782 |
variant="turbo",
|
| 783 |
device="cpu",
|
| 784 |
+
chunk_duration=60,
|
| 785 |
log_every=5,
|
| 786 |
):
|
| 787 |
elapsed = time.time() - train_start
|
train_engine.py
CHANGED
|
@@ -2250,56 +2250,45 @@ def preprocess_audio(
|
|
| 2250 |
del text_hs, text_mask, lyric_hs, lyric_mask
|
| 2251 |
continue
|
| 2252 |
|
| 2253 |
-
#
|
| 2254 |
-
|
| 2255 |
-
|
| 2256 |
-
|
| 2257 |
-
|
| 2258 |
-
text_hs_cpu = text_hs.cpu()
|
| 2259 |
-
text_mask_cpu = text_mask.cpu()
|
| 2260 |
-
lyric_hs_cpu = lyric_hs.cpu()
|
| 2261 |
-
lyric_mask_cpu = lyric_mask.cpu()
|
| 2262 |
-
silence_cpu = silence_lat.cpu()
|
| 2263 |
-
meta = {
|
| 2264 |
-
"audio_path": str(af),
|
| 2265 |
-
"filename": af.name,
|
| 2266 |
-
"caption": caption,
|
| 2267 |
-
"lyrics": lyrics,
|
| 2268 |
-
}
|
| 2269 |
-
del text_hs, text_mask, lyric_hs, lyric_mask
|
| 2270 |
|
| 2271 |
-
|
| 2272 |
-
|
| 2273 |
-
|
| 2274 |
-
|
| 2275 |
-
|
| 2276 |
-
|
| 2277 |
-
|
| 2278 |
-
|
| 2279 |
-
|
| 2280 |
-
|
| 2281 |
-
|
| 2282 |
-
|
| 2283 |
-
|
| 2284 |
-
|
| 2285 |
-
|
| 2286 |
-
|
| 2287 |
-
|
| 2288 |
-
|
| 2289 |
-
|
| 2290 |
-
|
| 2291 |
-
"
|
| 2292 |
-
"
|
| 2293 |
-
"
|
| 2294 |
-
"
|
| 2295 |
-
},
|
| 2296 |
-
|
| 2297 |
-
|
| 2298 |
-
|
| 2299 |
-
del
|
|
|
|
| 2300 |
|
| 2301 |
if progress_callback:
|
| 2302 |
-
progress_callback(i + 1, total, f"[Pass 1] {af.name}
|
| 2303 |
|
| 2304 |
except Exception as exc:
|
| 2305 |
p1_failed += 1
|
|
@@ -2419,6 +2408,7 @@ def train_lora_generator(
|
|
| 2419 |
target_modules: Optional[List[str]] = None,
|
| 2420 |
log_every: int = 10,
|
| 2421 |
resume_from: Optional[str] = None,
|
|
|
|
| 2422 |
) -> Generator[str, None, None]:
|
| 2423 |
"""Run LoRA training, yielding progress strings each epoch.
|
| 2424 |
|
|
@@ -2634,6 +2624,16 @@ def train_lora_generator(
|
|
| 2634 |
enc_mask = batch["encoder_attention_mask"].to(device, dtype=dtype, non_blocking=nb)
|
| 2635 |
ctx = batch["context_latents"].to(device, dtype=dtype, non_blocking=nb)
|
| 2636 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2637 |
bsz = tgt.shape[0]
|
| 2638 |
|
| 2639 |
# CFG dropout
|
|
|
|
| 2250 |
del text_hs, text_mask, lyric_hs, lyric_mask
|
| 2251 |
continue
|
| 2252 |
|
| 2253 |
+
# VAE encode full audio (tiled for memory, output is full-length)
|
| 2254 |
+
audio_in = audio.unsqueeze(0).to(device=device, dtype=vae.dtype)
|
| 2255 |
+
with torch.no_grad():
|
| 2256 |
+
target_latents = tiled_vae_encode(vae, audio_in, dtype)
|
| 2257 |
+
del audio_in, audio
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2258 |
|
| 2259 |
+
if torch.isnan(target_latents).any() or torch.isinf(target_latents).any():
|
| 2260 |
+
p1_failed += 1
|
| 2261 |
+
del target_latents, text_hs, text_mask, lyric_hs, lyric_mask
|
| 2262 |
+
continue
|
| 2263 |
+
|
| 2264 |
+
lat = target_latents.squeeze(0).cpu()
|
| 2265 |
+
lat_len = lat.shape[0]
|
| 2266 |
+
att_mask = torch.ones(lat_len, dtype=dtype)
|
| 2267 |
+
|
| 2268 |
+
tmp_path = out / f"{stem}.tmp.pt"
|
| 2269 |
+
torch.save({
|
| 2270 |
+
"target_latents": lat,
|
| 2271 |
+
"attention_mask": att_mask,
|
| 2272 |
+
"text_hidden_states": text_hs.cpu(),
|
| 2273 |
+
"text_attention_mask": text_mask.cpu(),
|
| 2274 |
+
"lyric_hidden_states": lyric_hs.cpu(),
|
| 2275 |
+
"lyric_attention_mask": lyric_mask.cpu(),
|
| 2276 |
+
"silence_latent": silence_lat.cpu(),
|
| 2277 |
+
"latent_length": lat_len,
|
| 2278 |
+
"metadata": {
|
| 2279 |
+
"audio_path": str(af),
|
| 2280 |
+
"filename": af.name,
|
| 2281 |
+
"caption": caption,
|
| 2282 |
+
"lyrics": lyrics,
|
| 2283 |
+
},
|
| 2284 |
+
}, tmp_path)
|
| 2285 |
+
intermediates.append(tmp_path)
|
| 2286 |
+
|
| 2287 |
+
del target_latents, lat, text_hs, text_mask, lyric_hs, lyric_mask
|
| 2288 |
+
logger.info("[OK] %s: %d latent frames (%.1fs)", af.name, lat_len, lat_len / LATENT_HZ)
|
| 2289 |
|
| 2290 |
if progress_callback:
|
| 2291 |
+
progress_callback(i + 1, total, f"[Pass 1] {af.name}")
|
| 2292 |
|
| 2293 |
except Exception as exc:
|
| 2294 |
p1_failed += 1
|
|
|
|
| 2408 |
target_modules: Optional[List[str]] = None,
|
| 2409 |
log_every: int = 10,
|
| 2410 |
resume_from: Optional[str] = None,
|
| 2411 |
+
chunk_duration: float = 0,
|
| 2412 |
) -> Generator[str, None, None]:
|
| 2413 |
"""Run LoRA training, yielding progress strings each epoch.
|
| 2414 |
|
|
|
|
| 2624 |
enc_mask = batch["encoder_attention_mask"].to(device, dtype=dtype, non_blocking=nb)
|
| 2625 |
ctx = batch["context_latents"].to(device, dtype=dtype, non_blocking=nb)
|
| 2626 |
|
| 2627 |
+
# Random crop to chunk_duration (data augmentation + speed)
|
| 2628 |
+
if chunk_duration > 0:
|
| 2629 |
+
max_len = int(chunk_duration * LATENT_HZ)
|
| 2630 |
+
T = tgt.shape[1]
|
| 2631 |
+
if T > max_len:
|
| 2632 |
+
start = random.randint(0, T - max_len)
|
| 2633 |
+
tgt = tgt[:, start:start + max_len, :]
|
| 2634 |
+
att = att[:, start:start + max_len]
|
| 2635 |
+
ctx = ctx[:, start:start + max_len, :]
|
| 2636 |
+
|
| 2637 |
bsz = tgt.shape[0]
|
| 2638 |
|
| 2639 |
# CFG dropout
|