audio-level chunking (not latent), auto-scale epochs for chunk count
- app.py: +10 -2
- train_engine.py: +56 -42
app.py CHANGED

@@ -749,7 +749,6 @@ def gradio_main():
         processed = result.get("processed", 0)
         failed = result.get("failed", 0)
         total = result.get("total", 0)
-        chunks = result.get("chunks", processed)
         _log(f"[OK] Preprocessed: {total} files -> {processed} training samples (failed: {failed})")
         yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
 
@@ -758,6 +757,15 @@ def gradio_main():
             yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
             return
 
+        # Auto-scale epochs: chunking multiplies samples, reduce epochs
+        # to keep total gradient updates ~constant
+        effective_epochs = epochs
+        if processed > total and total > 0:
+            scale = total / processed
+            effective_epochs = max(10, int(epochs * scale))
+            _log(f"[INFO] Auto-scaled epochs: {epochs} -> {effective_epochs} "
+                 f"(chunking: {total} files -> {processed} samples)")
+
         _gc.collect()
 
         # -- Phase 2: Training --
@@ -768,7 +776,7 @@ def gradio_main():
             dataset_dir=preprocessed_dir,
             output_dir=adapter_out,
             checkpoint_dir=ACE_CHECKPOINT_DIR,
-            epochs=epochs,
+            epochs=effective_epochs,
             lr=lr,
             rank=rank,
             alpha=rank * 2,
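The added block keeps the total number of gradient updates roughly constant: files × epochs before chunking ≈ samples × effective_epochs after chunking, with a floor of 10 epochs. A minimal standalone sketch of the same arithmetic (the helper name and the example numbers are illustrative, not part of the commit):

def auto_scale_epochs(epochs: int, total_files: int, processed_samples: int, floor: int = 10) -> int:
    # Same rule as the app.py block: shrink epochs by the chunking multiplier, never below the floor.
    if processed_samples > total_files > 0:
        scale = total_files / processed_samples
        return max(floor, int(epochs * scale))
    return epochs

# 12 files chunked into 48 samples, requested epochs=40:
# scale = 12/48 = 0.25 -> max(10, int(40 * 0.25)) = 10
# gradient updates stay ~constant: 12 samples * 40 epochs = 480 vs 48 samples * 10 epochs = 480
print(auto_scale_epochs(40, 12, 48))  # -> 10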
train_engine.py CHANGED

@@ -783,40 +783,54 @@ def encode_lyrics(text_encoder, tokenizer, lyrics: str, device, dtype):
 
 
 # ============================================================================
-#
+# AUDIO CHUNKING (split long audio into ~30s training samples)
 # ============================================================================
 
-
-
+CHUNK_MIN_SAMPLES = 20 * TARGET_SR  # 20s
+CHUNK_MAX_SAMPLES = 40 * TARGET_SR  # 40s
 
-
-
-
+def _chunk_audio(audio: torch.Tensor) -> List[torch.Tensor]:
+    """Split a [C, S] audio tensor into ~30s chunks for faster training.
+
+    Uses RMS energy to find the quietest point within the 20-40s window
+    around each cut, avoiding cuts through loud notes.
+    Short files (<=40s) are returned as-is.
     """
-
-    if
-        return [
+    S = audio.shape[-1]
+    if S <= CHUNK_MAX_SAMPLES:
+        return [audio]
+
+    mono = audio.mean(dim=0)  # [S]
+    hop = TARGET_SR // 10  # 0.1s resolution
+    frame_count = S // hop
+    rms = torch.zeros(frame_count)
+    for fi in range(frame_count):
+        seg = mono[fi * hop:(fi + 1) * hop]
+        rms[fi] = seg.pow(2).mean().sqrt()
 
-
+    min_frames = 20 * 10  # 20s in 0.1s frames
+    max_frames = 40 * 10  # 40s
 
     chunks = []
     pos = 0
-    while pos <
-        remaining =
-        if remaining <=
-
-            # Merge short tail into the previous chunk
-            chunks[-1] = latent[pos - chunks[-1].shape[0]:]
-        else:
-            chunks.append(latent[pos:])
+    while pos < frame_count:
+        remaining = frame_count - pos
+        if remaining <= max_frames:
+            chunks.append(audio[:, pos * hop:])
             break
 
-        search_start = pos +
-        search_end = min(pos +
-        window =
+        search_start = pos + min_frames
+        search_end = min(pos + max_frames, frame_count)
+        window = rms[search_start:search_end]
         cut = search_start + window.argmin().item()
 
-
+        # If cutting here leaves a tail shorter than 20s, take it all
+        tail = frame_count - cut
+        if tail < min_frames:
+            chunks.append(audio[:, pos * hop:])
+            break
+
+        chunks.append(audio[:, pos * hop:cut * hop])
         pos = cut
 
     return chunks
@@ -2181,16 +2195,6 @@ def preprocess_audio(
 
         try:
             audio, _ = load_audio_stereo(str(af), TARGET_SR, max_duration)
-            audio = audio.unsqueeze(0).to(device=device, dtype=vae.dtype)
-
-            with torch.no_grad():
-                target_latents = tiled_vae_encode(vae, audio, dtype)
-            del audio
-
-            if torch.isnan(target_latents).any() or torch.isinf(target_latents).any():
-                p1_failed += 1
-                del target_latents
-                continue
 
             # Auto-caption (once per file, shared across chunks)
             sidecar = _read_caption_sidecar(af)
@@ -2243,14 +2247,13 @@ def preprocess_audio(
             )
             if has_bad:
                 p1_failed += 1
-                del
+                del text_hs, text_mask, lyric_hs, lyric_mask
                 continue
 
-            #
-
-
-
-            logger.info("[Chunk] %s: %d frames -> %d chunks", af.name, T, len(chunks))
+            # Split audio into ~30s chunks, VAE encode each independently
+            audio_chunks = _chunk_audio(audio)
+            logger.info("[Chunk] %s: %.1fs -> %d chunks", af.name,
+                        audio.shape[-1] / TARGET_SR, len(audio_chunks))
 
             text_hs_cpu = text_hs.cpu()
             text_mask_cpu = text_mask.cpu()
@@ -2263,11 +2266,21 @@ def preprocess_audio(
                 "caption": caption,
                 "lyrics": lyrics,
             }
+            del text_hs, text_mask, lyric_hs, lyric_mask
 
-            for ci,
+            for ci, chunk_audio in enumerate(audio_chunks):
+                chunk_in = chunk_audio.unsqueeze(0).to(device=device, dtype=vae.dtype)
+                with torch.no_grad():
+                    chunk_lat = tiled_vae_encode(vae, chunk_in, dtype)
+                del chunk_in
+
+                if torch.isnan(chunk_lat).any() or torch.isinf(chunk_lat).any():
+                    continue
+
+                chunk_lat = chunk_lat.squeeze(0).cpu()
                 chunk_len = chunk_lat.shape[0]
                 chunk_mask = torch.ones(chunk_len, dtype=dtype)
-                tag = f"{stem}_chunk{ci}" if len(
+                tag = f"{stem}_chunk{ci}" if len(audio_chunks) > 1 else stem
                 tmp_path = out / f"{tag}.tmp.pt"
                 torch.save({
                     "target_latents": chunk_lat,
@@ -2281,11 +2294,12 @@ def preprocess_audio(
                     "metadata": meta,
                 }, tmp_path)
                 intermediates.append(tmp_path)
+                del chunk_lat
 
-            del
+            del audio
 
             if progress_callback:
-                progress_callback(i + 1, total, f"[Pass 1] {af.name} ({len(
+                progress_callback(i + 1, total, f"[Pass 1] {af.name} ({len(audio_chunks)} chunks)")
 
         except Exception as exc:
             p1_failed += 1