Nekochu committed
Commit 1ee8f1f · 1 Parent(s): 2e395ab

audio-level chunking (not latent), auto-scale epochs for chunk count

Files changed (2):
  1. app.py +10 -2
  2. train_engine.py +56 -42
app.py CHANGED
@@ -749,7 +749,6 @@ def gradio_main():
     processed = result.get("processed", 0)
     failed = result.get("failed", 0)
     total = result.get("total", 0)
-    chunks = result.get("chunks", processed)
     _log(f"[OK] Preprocessed: {total} files -> {processed} training samples (failed: {failed})")
     yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
 
@@ -758,6 +757,15 @@ def gradio_main():
         yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
         return
 
+    # Auto-scale epochs: chunking multiplies samples, reduce epochs
+    # to keep total gradient updates ~constant
+    effective_epochs = epochs
+    if processed > total and total > 0:
+        scale = total / processed
+        effective_epochs = max(10, int(epochs * scale))
+        _log(f"[INFO] Auto-scaled epochs: {epochs} -> {effective_epochs} "
+             f"(chunking: {total} files -> {processed} samples)")
+
     _gc.collect()
 
     # -- Phase 2: Training --
@@ -768,7 +776,7 @@ def gradio_main():
         dataset_dir=preprocessed_dir,
         output_dir=adapter_out,
         checkpoint_dir=ACE_CHECKPOINT_DIR,
-        epochs=epochs,
+        epochs=effective_epochs,
         lr=lr,
         rank=rank,
         alpha=rank * 2,
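The scaling rule keeps epochs * samples, and with it the total number of gradient updates at a fixed batch size, roughly constant. A quick sketch of the arithmetic with made-up numbers (10 files chunked into 40 samples, 100 requested epochs; all values hypothetical):

    # Hypothetical numbers, for illustration only.
    epochs, total, processed = 100, 10, 40

    effective_epochs = epochs
    if processed > total and total > 0:
        scale = total / processed                        # 0.25
        effective_epochs = max(10, int(epochs * scale))  # 25

    # updates before chunking: 100 epochs * 10 samples = 1000
    # updates after scaling:    25 epochs * 40 samples = 1000
    print(effective_epochs)  # 25

The max(10, ...) floor means heavily chunked datasets still get at least 10 passes.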
train_engine.py CHANGED
@@ -783,40 +783,54 @@ def encode_lyrics(text_encoder, tokenizer, lyrics: str, device, dtype):
 
 
 # ============================================================================
-# LATENT CHUNKING (split long latents into ~30s training samples)
+# AUDIO CHUNKING (split long audio into ~30s training samples)
 # ============================================================================
 
-def _chunk_latents(latent: torch.Tensor) -> List[torch.Tensor]:
-    """Split a [T, C] latent into ~30s chunks for faster training.
+CHUNK_MIN_SAMPLES = 20 * TARGET_SR  # 20s
+CHUNK_MAX_SAMPLES = 40 * TARGET_SR  # 40s
+
+def _chunk_audio(audio: torch.Tensor) -> List[torch.Tensor]:
+    """Split a [C, S] audio tensor into ~30s chunks for faster training.
 
-    Uses energy-based boundary detection: finds the lowest-energy frame
-    within the 20-40s window around each cut point, avoiding cuts through
-    loud notes. Short files (<=40s) are returned as-is.
+    Uses RMS energy to find the quietest point within the 20-40s window
+    around each cut, avoiding cuts through loud notes.
+    Short files (<=40s) are returned as-is.
     """
-    T = latent.shape[0]
-    if T <= CHUNK_LATENT_MAX:
-        return [latent]
+    S = audio.shape[-1]
+    if S <= CHUNK_MAX_SAMPLES:
+        return [audio]
+
+    mono = audio.mean(dim=0)  # [S]
+    hop = TARGET_SR // 10  # 0.1s resolution
+    frame_count = S // hop
+    rms = torch.zeros(frame_count)
+    for fi in range(frame_count):
+        seg = mono[fi * hop:(fi + 1) * hop]
+        rms[fi] = seg.pow(2).mean().sqrt()
 
-    energy = latent.pow(2).mean(dim=-1)  # [T] per-frame energy
+    min_frames = 20 * 10  # 20s in 0.1s frames
+    max_frames = 40 * 10  # 40s
 
     chunks = []
     pos = 0
-    while pos < T:
-        remaining = T - pos
-        if remaining <= CHUNK_LATENT_MAX:
-            if chunks and remaining < CHUNK_LATENT_MIN:
-                # Merge short tail into the previous chunk
-                chunks[-1] = latent[pos - chunks[-1].shape[0]:]
-            else:
-                chunks.append(latent[pos:])
+    while pos < frame_count:
+        remaining = frame_count - pos
+        if remaining <= max_frames:
+            chunks.append(audio[:, pos * hop:])
             break
 
-        search_start = pos + CHUNK_LATENT_MIN
-        search_end = min(pos + CHUNK_LATENT_MAX, T)
-        window = energy[search_start:search_end]
+        search_start = pos + min_frames
+        search_end = min(pos + max_frames, frame_count)
+        window = rms[search_start:search_end]
         cut = search_start + window.argmin().item()
 
-        chunks.append(latent[pos:cut])
+        # If cutting here leaves a tail shorter than 20s, take it all
+        tail = frame_count - cut
+        if tail < min_frames:
+            chunks.append(audio[:, pos * hop:])
+            break
+
+        chunks.append(audio[:, pos * hop:cut * hop])
         pos = cut
 
     return chunks
@@ -2181,16 +2195,6 @@ def preprocess_audio(
 
         try:
             audio, _ = load_audio_stereo(str(af), TARGET_SR, max_duration)
-            audio = audio.unsqueeze(0).to(device=device, dtype=vae.dtype)
-
-            with torch.no_grad():
-                target_latents = tiled_vae_encode(vae, audio, dtype)
-            del audio
-
-            if torch.isnan(target_latents).any() or torch.isinf(target_latents).any():
-                p1_failed += 1
-                del target_latents
-                continue
 
             # Auto-caption (once per file, shared across chunks)
             sidecar = _read_caption_sidecar(af)
@@ -2243,14 +2247,13 @@
             )
             if has_bad:
                 p1_failed += 1
-                del target_latents, text_hs, text_mask, lyric_hs, lyric_mask
+                del text_hs, text_mask, lyric_hs, lyric_mask
                 continue
 
-            # Chunk latents into ~30s segments for faster training
-            full_lat = target_latents.squeeze(0).cpu()  # [T, C]
-            T = full_lat.shape[0]
-            chunks = _chunk_latents(full_lat)
-            logger.info("[Chunk] %s: %d frames -> %d chunks", af.name, T, len(chunks))
+            # Split audio into ~30s chunks, VAE encode each independently
+            audio_chunks = _chunk_audio(audio)
+            logger.info("[Chunk] %s: %.1fs -> %d chunks", af.name,
+                        audio.shape[-1] / TARGET_SR, len(audio_chunks))
 
             text_hs_cpu = text_hs.cpu()
             text_mask_cpu = text_mask.cpu()
@@ -2263,11 +2266,21 @@
                 "caption": caption,
                 "lyrics": lyrics,
             }
+            del text_hs, text_mask, lyric_hs, lyric_mask
 
-            for ci, chunk_lat in enumerate(chunks):
+            for ci, chunk_audio in enumerate(audio_chunks):
+                chunk_in = chunk_audio.unsqueeze(0).to(device=device, dtype=vae.dtype)
+                with torch.no_grad():
+                    chunk_lat = tiled_vae_encode(vae, chunk_in, dtype)
+                del chunk_in
+
+                if torch.isnan(chunk_lat).any() or torch.isinf(chunk_lat).any():
+                    continue
+
+                chunk_lat = chunk_lat.squeeze(0).cpu()
                 chunk_len = chunk_lat.shape[0]
                 chunk_mask = torch.ones(chunk_len, dtype=dtype)
-                tag = f"{stem}_chunk{ci}" if len(chunks) > 1 else stem
+                tag = f"{stem}_chunk{ci}" if len(audio_chunks) > 1 else stem
                 tmp_path = out / f"{tag}.tmp.pt"
                 torch.save({
                     "target_latents": chunk_lat,
@@ -2281,11 +2294,12 @@
                     "metadata": meta,
                 }, tmp_path)
                 intermediates.append(tmp_path)
+                del chunk_lat
 
-            del target_latents, full_lat, text_hs, text_mask, lyric_hs, lyric_mask
+            del audio
 
             if progress_callback:
-                progress_callback(i + 1, total, f"[Pass 1] {af.name} ({len(chunks)} chunks)")
+                progress_callback(i + 1, total, f"[Pass 1] {af.name} ({len(audio_chunks)} chunks)")
 
         except Exception as exc:
             p1_failed += 1
 
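A minimal usage sketch for the new chunker, assuming train_engine imports cleanly in your environment so _chunk_audio and TARGET_SR can be pulled from it (the 3-minute stereo noise input is arbitrary):

    import torch
    from train_engine import _chunk_audio, TARGET_SR

    audio = torch.randn(2, 180 * TARGET_SR)  # [C, S]: 3 minutes of stereo noise
    chunks = _chunk_audio(audio)

    # Cuts land on 0.1s frame boundaries, so the pieces reassemble exactly.
    assert torch.equal(torch.cat(chunks, dim=-1), audio)
    # Cut points sit 20-40s apart and a <20s tail merges into the final
    # chunk, so every chunk ends up between 20s and just under 60s.
    secs = [c.shape[-1] / TARGET_SR for c in chunks]
    assert all(20.0 <= s < 60.0 for s in secs)
    print(len(chunks), [round(s, 1) for s in secs])

As a design note, the per-frame RMS loop could also be vectorized as mono[:frame_count * hop].reshape(frame_count, hop).pow(2).mean(dim=1).sqrt(), which avoids a Python loop on long files.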
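Two side effects of moving the VAE encode inside the chunk loop are visible in the diff: a NaN/Inf encode now skips only the offending chunk instead of failing the whole file, and only one ~30s chunk occupies the GPU at a time. A self-contained sketch of that pattern, with a toy pooling function standing in for tiled_vae_encode (sample rate and shapes here are illustrative, not the repo's):

    import torch
    import torch.nn.functional as F

    def encode_chunks(chunks, encode_fn, device="cpu"):
        latents = []
        for chunk in chunks:
            x = chunk.unsqueeze(0).to(device)  # [1, C, S], one chunk at a time
            with torch.no_grad():
                lat = encode_fn(x)
            del x
            if torch.isnan(lat).any() or torch.isinf(lat).any():
                continue  # drop this chunk only, keep the rest of the file
            latents.append(lat.squeeze(0).cpu())  # latents stay on CPU
        return latents

    toy_encode = lambda x: F.avg_pool1d(x, kernel_size=512)  # stand-in "VAE"
    lats = encode_chunks([torch.randn(2, 44100 * 30)], toy_encode)
    print([t.shape for t in lats])  # [torch.Size([2, 2583])]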