Nekochu committed
Commit 2e395ab · Parent: 81f54b1

chunk latents into ~30s segments for faster CPU training, energy-aware boundaries

Files changed (2):
  1. app.py +2 -1
  2. train_engine.py +88 -43
app.py CHANGED
@@ -749,7 +749,8 @@ def gradio_main():
     processed = result.get("processed", 0)
     failed = result.get("failed", 0)
     total = result.get("total", 0)
-    _log(f"[OK] Preprocessed: {processed}/{total} (failed: {failed})")
+    chunks = result.get("chunks", processed)
+    _log(f"[OK] Preprocessed: {total} files -> {chunks} training samples (failed: {failed})")
     yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
 
     if processed == 0:
train_engine.py CHANGED
@@ -64,6 +64,10 @@ logger = logging.getLogger(__name__)
 MAX_AUDIO_DURATION = 240.0  # seconds, cap per audio file
 MAX_TRAINING_TIME = 28800  # 8 hours hard timeout
 TARGET_SR = 48000
+LATENT_HZ = 25  # latent frames per second (48000 / 1920)
+CHUNK_LATENT_MIN = 20 * LATENT_HZ  # 500 frames (20s)
+CHUNK_LATENT_TARGET = 30 * LATENT_HZ  # 750 frames (30s)
+CHUNK_LATENT_MAX = 40 * LATENT_HZ  # 1000 frames (40s)
 AUDIO_EXTENSIONS = frozenset({".wav", ".mp3", ".flac", ".ogg", ".opus", ".m4a", ".aac"})
 
 # bfloat16 deadlocks on CPU (known PyTorch bug) -- force float32
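
The 25 Hz rate and the frame thresholds follow directly from the sample rate and the VAE hop size. A quick standalone sanity check (illustrative, not part of the commit; VAE_HOP is an assumed name for the 1920-sample hop implied by the "48000 / 1920" comment):

TARGET_SR = 48000
VAE_HOP = 1920                      # assumed: samples per latent frame
LATENT_HZ = TARGET_SR // VAE_HOP    # 25 latent frames per second

assert LATENT_HZ == 25
assert 20 * LATENT_HZ == 500        # CHUNK_LATENT_MIN    -> 20 s
assert 30 * LATENT_HZ == 750        # CHUNK_LATENT_TARGET -> 30 s
assert 40 * LATENT_HZ == 1000       # CHUNK_LATENT_MAX    -> 40 s

# A file at the 240 s MAX_AUDIO_DURATION cap encodes to 240 * 25 = 6000
# frames, so it will yield roughly 6-12 chunks of 20-40 s each.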
@@ -778,6 +782,46 @@ def encode_lyrics(text_encoder, tokenizer, lyrics: str, device, dtype):
     return hs, mask
 
 
+# ============================================================================
+# LATENT CHUNKING (split long latents into ~30s training samples)
+# ============================================================================
+
+def _chunk_latents(latent: torch.Tensor) -> List[torch.Tensor]:
+    """Split a [T, C] latent into ~30s chunks for faster training.
+
+    Uses energy-based boundary detection: finds the lowest-energy frame
+    within the 20-40s window around each cut point, avoiding cuts through
+    loud notes. Short files (<=40s) are returned as-is.
+    """
+    T = latent.shape[0]
+    if T <= CHUNK_LATENT_MAX:
+        return [latent]
+
+    energy = latent.pow(2).mean(dim=-1)  # [T] per-frame energy
+
+    chunks = []
+    pos = 0
+    while pos < T:
+        remaining = T - pos
+        if remaining <= CHUNK_LATENT_MAX:
+            if chunks and remaining < CHUNK_LATENT_MIN:
+                # Merge short tail into the previous chunk
+                chunks[-1] = latent[pos - chunks[-1].shape[0]:]
+            else:
+                chunks.append(latent[pos:])
+            break
+
+        search_start = pos + CHUNK_LATENT_MIN
+        search_end = min(pos + CHUNK_LATENT_MAX, T)
+        window = energy[search_start:search_end]
+        cut = search_start + window.argmin().item()
+
+        chunks.append(latent[pos:cut])
+        pos = cut
+
+    return chunks
+
+
 # ============================================================================
 # VAE TILED ENCODING
 # ============================================================================
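
A minimal usage sketch for the new helper (illustrative only; assumes _chunk_latents and the chunk constants are importable from train_engine, and substitutes a random tensor for a real VAE latent):

import torch

from train_engine import _chunk_latents, CHUNK_LATENT_MIN, CHUNK_LATENT_MAX

# Fake 2-minute latent: 120 s * 25 Hz = 3000 frames; the channel count
# (64 here) is arbitrary -- chunking only looks at the time dimension.
latent = torch.randn(3000, 64)

chunks = _chunk_latents(latent)
print([c.shape[0] for c in chunks])

# Every chunk is at least 20 s long; a merged tail can exceed 40 s but
# never reaches MAX + MIN, and the chunks reassemble losslessly.
assert all(CHUNK_LATENT_MIN <= c.shape[0] < CHUNK_LATENT_MAX + CHUNK_LATENT_MIN for c in chunks)
assert torch.equal(torch.cat(chunks, dim=0), latent)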
@@ -2148,18 +2192,13 @@ def preprocess_audio(
                 del target_latents
                 continue
 
-            lat_len = target_latents.shape[1]
-            att_mask = torch.ones(1, lat_len, device=device, dtype=dtype)
-
-            # Auto-caption: read existing sidecar or analyze
+            # Auto-caption (once per file, shared across chunks)
             sidecar = _read_caption_sidecar(af)
             if sidecar is not None:
                 caption = sidecar.get("caption", "") or af.stem
                 lyrics = sidecar.get("lyrics", "[Instrumental]")
                 logger.info("[Caption] %s: using existing sidecar", af.name)
             else:
-                # Auto-select analysis mode based on dataset size
-                # mid/sas use Demucs stem separation — GPU only
                 if device == "cpu":
                     analysis_mode = "faf"
                 elif total <= 20:
@@ -2169,7 +2208,6 @@ def preprocess_audio(
                 else:
                     analysis_mode = "faf"
 
-                # Log mode selection with reasoning (first file only)
                 if i == 0:
                     _MODE_DESC = {
                         "faf": "fast, ~3s/file",
@@ -2177,19 +2215,9 @@ def preprocess_audio(
                         "sas": "best quality, ~30s/file on GPU, slower on CPU",
                     }
                     logger.info(
-                        "[Analysis] Mode auto-selected: '%s' (%s) "
-                        "for %d files (<=20: sas, 21-100: mid, 100+: faf)",
+                        "[Analysis] Mode '%s' (%s) for %d files",
                         analysis_mode, _MODE_DESC[analysis_mode], total,
                     )
-                if analysis_mode in ("mid", "sas") and device == "cpu":
-                    logger.warning(
-                        "[Analysis] Mode '%s' uses Demucs stem separation "
-                        "which is SLOW on CPU (~2-5 min/file). "
-                        "Total estimated time: ~%d-%d min for %d files. "
-                        "Use 'faf' mode or a GPU machine for faster processing.",
-                        analysis_mode,
-                        total * 2, total * 5, total,
-                    )
 
                 try:
                     logger.info("[Caption] %s: analyzing (mode=%s)...", af.name, analysis_mode)
@@ -2204,10 +2232,9 @@ def preprocess_audio(
                     logger.warning("[Caption] %s: analysis failed (%s), using filename", af.name, exc)
                     caption = af.stem
                     lyrics = "[Instrumental]"
-            text_prompt = caption
 
             with torch.no_grad():
-                text_hs, text_mask = encode_text(text_enc, tokenizer, text_prompt, device, dtype)
+                text_hs, text_mask = encode_text(text_enc, tokenizer, caption, device, dtype)
                 lyric_hs, lyric_mask = encode_lyrics(text_enc, tokenizer, lyrics, device, dtype)
 
             has_bad = any(
@@ -2216,32 +2243,49 @@ def preprocess_audio(
             )
             if has_bad:
                 p1_failed += 1
-                del target_latents, att_mask, text_hs, text_mask, lyric_hs, lyric_mask
+                del target_latents, text_hs, text_mask, lyric_hs, lyric_mask
                 continue
 
-            tmp_path = out / f"{stem}.tmp.pt"
-            torch.save({
-                "target_latents": target_latents.squeeze(0).cpu(),
-                "attention_mask": att_mask.squeeze(0).cpu(),
-                "text_hidden_states": text_hs.cpu(),
-                "text_attention_mask": text_mask.cpu(),
-                "lyric_hidden_states": lyric_hs.cpu(),
-                "lyric_attention_mask": lyric_mask.cpu(),
-                "silence_latent": silence_lat.cpu(),
-                "latent_length": lat_len,
-                "metadata": {
-                    "audio_path": str(af),
-                    "filename": af.name,
-                    "caption": caption,
-                    "lyrics": lyrics,
-                },
-            }, tmp_path)
-
-            del target_latents, att_mask, text_hs, text_mask, lyric_hs, lyric_mask
-            intermediates.append(tmp_path)
+            # Chunk latents into ~30s segments for faster training
+            full_lat = target_latents.squeeze(0).cpu()  # [T, C]
+            T = full_lat.shape[0]
+            chunks = _chunk_latents(full_lat)
+            logger.info("[Chunk] %s: %d frames -> %d chunks", af.name, T, len(chunks))
+
+            text_hs_cpu = text_hs.cpu()
+            text_mask_cpu = text_mask.cpu()
+            lyric_hs_cpu = lyric_hs.cpu()
+            lyric_mask_cpu = lyric_mask.cpu()
+            silence_cpu = silence_lat.cpu()
+            meta = {
+                "audio_path": str(af),
+                "filename": af.name,
+                "caption": caption,
+                "lyrics": lyrics,
+            }
+
+            for ci, chunk_lat in enumerate(chunks):
+                chunk_len = chunk_lat.shape[0]
+                chunk_mask = torch.ones(chunk_len, dtype=dtype)
+                tag = f"{stem}_chunk{ci}" if len(chunks) > 1 else stem
+                tmp_path = out / f"{tag}.tmp.pt"
+                torch.save({
+                    "target_latents": chunk_lat,
+                    "attention_mask": chunk_mask,
+                    "text_hidden_states": text_hs_cpu,
+                    "text_attention_mask": text_mask_cpu,
+                    "lyric_hidden_states": lyric_hs_cpu,
+                    "lyric_attention_mask": lyric_mask_cpu,
+                    "silence_latent": silence_cpu,
+                    "latent_length": chunk_len,
+                    "metadata": meta,
+                }, tmp_path)
+                intermediates.append(tmp_path)
+
+            del target_latents, full_lat, text_hs, text_mask, lyric_hs, lyric_mask
 
             if progress_callback:
-                progress_callback(i + 1, total, f"[Pass 1] {af.name}")
+                progress_callback(i + 1, total, f"[Pass 1] {af.name} ({len(chunks)} chunks)")
 
         except Exception as exc:
             p1_failed += 1
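
Each chunk is written out as an independent training sample that shares the file-level conditioning tensors. For reference, a saved chunk can be inspected on its own; a sketch (the directory and filename below are hypothetical examples of the naming scheme above):

import torch

# e.g. the second chunk of "song.wav" under an assumed output directory
sample = torch.load("preprocessed/song_chunk1.tmp.pt", map_location="cpu")

print(sample["target_latents"].shape)  # [chunk_len, C]
print(sample["metadata"]["caption"])   # identical for all chunks of the file
assert sample["attention_mask"].shape[0] == sample["latent_length"]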
@@ -2329,7 +2373,8 @@ def preprocess_audio(
         _clear_gpu_cache(device)
 
         failed = p1_failed + p2_failed
-        return {"processed": processed, "failed": failed, "total": total, "output_dir": str(out)}
+        return {"processed": processed, "failed": failed, "total": total,
+                "chunks": len(intermediates), "output_dir": str(out)}
 
 
 # ============================================================================
 