Nekochu committed on
Commit
d3618ec
·
1 Parent(s): 1ee8f1f

Random 60s crop at training time (matches Side-Step chunk duration); remove pre-split chunking

Browse files
Files changed (2) hide show
  1. app.py +4 -12
  2. train_engine.py +47 -47
app.py CHANGED
@@ -749,7 +749,7 @@ def gradio_main():
749
  processed = result.get("processed", 0)
750
  failed = result.get("failed", 0)
751
  total = result.get("total", 0)
752
- _log(f"[OK] Preprocessed: {total} files -> {processed} training samples (failed: {failed})")
753
  yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
754
 
755
  if processed == 0:
@@ -757,18 +757,9 @@ def gradio_main():
757
  yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
758
  return
759
 
760
- # Auto-scale epochs: chunking multiplies samples, reduce epochs
761
- # to keep total gradient updates ~constant
762
- effective_epochs = epochs
763
- if processed > total and total > 0:
764
- scale = total / processed
765
- effective_epochs = max(10, int(epochs * scale))
766
- _log(f"[INFO] Auto-scaled epochs: {epochs} -> {effective_epochs} "
767
- f"(chunking: {total} files -> {processed} samples)")
768
-
769
  _gc.collect()
770
 
771
- # -- Phase 2: Training --
772
  _log("[Step 2/2] Training LoRA...")
773
  yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
774
 
@@ -776,7 +767,7 @@ def gradio_main():
776
  dataset_dir=preprocessed_dir,
777
  output_dir=adapter_out,
778
  checkpoint_dir=ACE_CHECKPOINT_DIR,
779
- epochs=effective_epochs,
780
  lr=lr,
781
  rank=rank,
782
  alpha=rank * 2,
@@ -790,6 +781,7 @@ def gradio_main():
790
  seed=42,
791
  variant="turbo",
792
  device="cpu",
 
793
  log_every=5,
794
  ):
795
  elapsed = time.time() - train_start
 
749
  processed = result.get("processed", 0)
750
  failed = result.get("failed", 0)
751
  total = result.get("total", 0)
752
+ _log(f"[OK] Preprocessed: {processed}/{total} files (failed: {failed})")
753
  yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
754
 
755
  if processed == 0:
 
757
  yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
758
  return
759
 
 
 
 
 
 
 
 
 
 
760
  _gc.collect()
761
 
762
+ # -- Phase 2: Training (random 60s crops for speed + augmentation) --
763
  _log("[Step 2/2] Training LoRA...")
764
  yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
765
 
 
767
  dataset_dir=preprocessed_dir,
768
  output_dir=adapter_out,
769
  checkpoint_dir=ACE_CHECKPOINT_DIR,
770
+ epochs=epochs,
771
  lr=lr,
772
  rank=rank,
773
  alpha=rank * 2,
 
781
  seed=42,
782
  variant="turbo",
783
  device="cpu",
784
+ chunk_duration=60,
785
  log_every=5,
786
  ):
787
  elapsed = time.time() - train_start
train_engine.py CHANGED
@@ -2250,56 +2250,45 @@ def preprocess_audio(
2250
  del text_hs, text_mask, lyric_hs, lyric_mask
2251
  continue
2252
 
2253
- # Split audio into ~30s chunks, VAE encode each independently
2254
- audio_chunks = _chunk_audio(audio)
2255
- logger.info("[Chunk] %s: %.1fs -> %d chunks", af.name,
2256
- audio.shape[-1] / TARGET_SR, len(audio_chunks))
2257
-
2258
- text_hs_cpu = text_hs.cpu()
2259
- text_mask_cpu = text_mask.cpu()
2260
- lyric_hs_cpu = lyric_hs.cpu()
2261
- lyric_mask_cpu = lyric_mask.cpu()
2262
- silence_cpu = silence_lat.cpu()
2263
- meta = {
2264
- "audio_path": str(af),
2265
- "filename": af.name,
2266
- "caption": caption,
2267
- "lyrics": lyrics,
2268
- }
2269
- del text_hs, text_mask, lyric_hs, lyric_mask
2270
 
2271
- for ci, chunk_audio in enumerate(audio_chunks):
2272
- chunk_in = chunk_audio.unsqueeze(0).to(device=device, dtype=vae.dtype)
2273
- with torch.no_grad():
2274
- chunk_lat = tiled_vae_encode(vae, chunk_in, dtype)
2275
- del chunk_in
2276
-
2277
- if torch.isnan(chunk_lat).any() or torch.isinf(chunk_lat).any():
2278
- continue
2279
-
2280
- chunk_lat = chunk_lat.squeeze(0).cpu()
2281
- chunk_len = chunk_lat.shape[0]
2282
- chunk_mask = torch.ones(chunk_len, dtype=dtype)
2283
- tag = f"{stem}_chunk{ci}" if len(audio_chunks) > 1 else stem
2284
- tmp_path = out / f"{tag}.tmp.pt"
2285
- torch.save({
2286
- "target_latents": chunk_lat,
2287
- "attention_mask": chunk_mask,
2288
- "text_hidden_states": text_hs_cpu,
2289
- "text_attention_mask": text_mask_cpu,
2290
- "lyric_hidden_states": lyric_hs_cpu,
2291
- "lyric_attention_mask": lyric_mask_cpu,
2292
- "silence_latent": silence_cpu,
2293
- "latent_length": chunk_len,
2294
- "metadata": meta,
2295
- }, tmp_path)
2296
- intermediates.append(tmp_path)
2297
- del chunk_lat
2298
-
2299
- del audio
 
2300
 
2301
  if progress_callback:
2302
- progress_callback(i + 1, total, f"[Pass 1] {af.name} ({len(audio_chunks)} chunks)")
2303
 
2304
  except Exception as exc:
2305
  p1_failed += 1
@@ -2419,6 +2408,7 @@ def train_lora_generator(
2419
  target_modules: Optional[List[str]] = None,
2420
  log_every: int = 10,
2421
  resume_from: Optional[str] = None,
 
2422
  ) -> Generator[str, None, None]:
2423
  """Run LoRA training, yielding progress strings each epoch.
2424
 
@@ -2634,6 +2624,16 @@ def train_lora_generator(
2634
  enc_mask = batch["encoder_attention_mask"].to(device, dtype=dtype, non_blocking=nb)
2635
  ctx = batch["context_latents"].to(device, dtype=dtype, non_blocking=nb)
2636
 
 
 
 
 
 
 
 
 
 
 
2637
  bsz = tgt.shape[0]
2638
 
2639
  # CFG dropout
 
2250
  del text_hs, text_mask, lyric_hs, lyric_mask
2251
  continue
2252
 
2253
+ # VAE encode full audio (tiled for memory, output is full-length)
2254
+ audio_in = audio.unsqueeze(0).to(device=device, dtype=vae.dtype)
2255
+ with torch.no_grad():
2256
+ target_latents = tiled_vae_encode(vae, audio_in, dtype)
2257
+ del audio_in, audio
 
 
 
 
 
 
 
 
 
 
 
 
2258
 
2259
+ if torch.isnan(target_latents).any() or torch.isinf(target_latents).any():
2260
+ p1_failed += 1
2261
+ del target_latents, text_hs, text_mask, lyric_hs, lyric_mask
2262
+ continue
2263
+
2264
+ lat = target_latents.squeeze(0).cpu()
2265
+ lat_len = lat.shape[0]
2266
+ att_mask = torch.ones(lat_len, dtype=dtype)
2267
+
2268
+ tmp_path = out / f"{stem}.tmp.pt"
2269
+ torch.save({
2270
+ "target_latents": lat,
2271
+ "attention_mask": att_mask,
2272
+ "text_hidden_states": text_hs.cpu(),
2273
+ "text_attention_mask": text_mask.cpu(),
2274
+ "lyric_hidden_states": lyric_hs.cpu(),
2275
+ "lyric_attention_mask": lyric_mask.cpu(),
2276
+ "silence_latent": silence_lat.cpu(),
2277
+ "latent_length": lat_len,
2278
+ "metadata": {
2279
+ "audio_path": str(af),
2280
+ "filename": af.name,
2281
+ "caption": caption,
2282
+ "lyrics": lyrics,
2283
+ },
2284
+ }, tmp_path)
2285
+ intermediates.append(tmp_path)
2286
+
2287
+ del target_latents, lat, text_hs, text_mask, lyric_hs, lyric_mask
2288
+ logger.info("[OK] %s: %d latent frames (%.1fs)", af.name, lat_len, lat_len / LATENT_HZ)
2289
 
2290
  if progress_callback:
2291
+ progress_callback(i + 1, total, f"[Pass 1] {af.name}")
2292
 
2293
  except Exception as exc:
2294
  p1_failed += 1
 
2408
  target_modules: Optional[List[str]] = None,
2409
  log_every: int = 10,
2410
  resume_from: Optional[str] = None,
2411
+ chunk_duration: float = 0,
2412
  ) -> Generator[str, None, None]:
2413
  """Run LoRA training, yielding progress strings each epoch.
2414
 
 
2624
  enc_mask = batch["encoder_attention_mask"].to(device, dtype=dtype, non_blocking=nb)
2625
  ctx = batch["context_latents"].to(device, dtype=dtype, non_blocking=nb)
2626
 
2627
+ # Random crop to chunk_duration (data augmentation + speed)
2628
+ if chunk_duration > 0:
2629
+ max_len = int(chunk_duration * LATENT_HZ)
2630
+ T = tgt.shape[1]
2631
+ if T > max_len:
2632
+ start = random.randint(0, T - max_len)
2633
+ tgt = tgt[:, start:start + max_len, :]
2634
+ att = att[:, start:start + max_len]
2635
+ ctx = ctx[:, start:start + max_len, :]
2636
+
2637
  bsz = tgt.shape[0]
2638
 
2639
  # CFG dropout