Spaces:
Running
Running
chunk latents into ~30s segments for faster CPU training, energy-aware boundaries
Browse files
- app.py +2 -1
- train_engine.py +88 -43
app.py
CHANGED
|
@@ -749,7 +749,8 @@ def gradio_main():
|
|
| 749 |
processed = result.get("processed", 0)
|
| 750 |
failed = result.get("failed", 0)
|
| 751 |
total = result.get("total", 0)
|
| 752 |
-
|
|
|
|
| 753 |
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
|
| 754 |
|
| 755 |
if processed == 0:
|
|
|
|
| 749 |
processed = result.get("processed", 0)
|
| 750 |
failed = result.get("failed", 0)
|
| 751 |
total = result.get("total", 0)
|
| 752 |
+
chunks = result.get("chunks", processed)
|
| 753 |
+
_log(f"[OK] Preprocessed: {total} files -> {processed} training samples (failed: {failed})")
|
| 754 |
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
|
| 755 |
|
| 756 |
if processed == 0:
|
train_engine.py
CHANGED
|
@@ -64,6 +64,10 @@ logger = logging.getLogger(__name__)
|
|
| 64 |
MAX_AUDIO_DURATION = 240.0 # seconds, cap per audio file
|
| 65 |
MAX_TRAINING_TIME = 28800 # 8 hours hard timeout
|
| 66 |
TARGET_SR = 48000
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
AUDIO_EXTENSIONS = frozenset({".wav", ".mp3", ".flac", ".ogg", ".opus", ".m4a", ".aac"})
|
| 68 |
|
| 69 |
# bfloat16 deadlocks on CPU (known PyTorch bug) -- force float32
|
|
@@ -778,6 +782,46 @@ def encode_lyrics(text_encoder, tokenizer, lyrics: str, device, dtype):
|
|
| 778 |
return hs, mask
|
| 779 |
|
| 780 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 781 |
# ============================================================================
|
| 782 |
# VAE TILED ENCODING
|
| 783 |
# ============================================================================
|
|
@@ -2148,18 +2192,13 @@ def preprocess_audio(
|
|
| 2148 |
del target_latents
|
| 2149 |
continue
|
| 2150 |
|
| 2151 |
-
|
| 2152 |
-
att_mask = torch.ones(1, lat_len, device=device, dtype=dtype)
|
| 2153 |
-
|
| 2154 |
-
# Auto-caption: read existing sidecar or analyze
|
| 2155 |
sidecar = _read_caption_sidecar(af)
|
| 2156 |
if sidecar is not None:
|
| 2157 |
caption = sidecar.get("caption", "") or af.stem
|
| 2158 |
lyrics = sidecar.get("lyrics", "[Instrumental]")
|
| 2159 |
logger.info("[Caption] %s: using existing sidecar", af.name)
|
| 2160 |
else:
|
| 2161 |
-
# Auto-select analysis mode based on dataset size
|
| 2162 |
-
# mid/sas use Demucs stem separation — GPU only
|
| 2163 |
if device == "cpu":
|
| 2164 |
analysis_mode = "faf"
|
| 2165 |
elif total <= 20:
|
|
@@ -2169,7 +2208,6 @@ def preprocess_audio(
|
|
| 2169 |
else:
|
| 2170 |
analysis_mode = "faf"
|
| 2171 |
|
| 2172 |
-
# Log mode selection with reasoning (first file only)
|
| 2173 |
if i == 0:
|
| 2174 |
_MODE_DESC = {
|
| 2175 |
"faf": "fast, ~3s/file",
|
|
@@ -2177,19 +2215,9 @@ def preprocess_audio(
|
|
| 2177 |
"sas": "best quality, ~30s/file on GPU, slower on CPU",
|
| 2178 |
}
|
| 2179 |
logger.info(
|
| 2180 |
-
"[Analysis] Mode '%s' (%s) "
|
| 2181 |
-
"for %d files (<=20: sas, 21-100: mid, 100+: faf)",
|
| 2182 |
analysis_mode, _MODE_DESC[analysis_mode], total,
|
| 2183 |
)
|
| 2184 |
-
if analysis_mode in ("mid", "sas") and device == "cpu":
|
| 2185 |
-
logger.warning(
|
| 2186 |
-
"[Analysis] Mode '%s' uses Demucs stem separation "
|
| 2187 |
-
"which is SLOW on CPU (~2-5 min/file). "
|
| 2188 |
-
"Total estimated time: ~%d-%d min for %d files. "
|
| 2189 |
-
"Use 'faf' mode or a GPU machine for faster processing.",
|
| 2190 |
-
analysis_mode,
|
| 2191 |
-
total * 2, total * 5, total,
|
| 2192 |
-
)
|
| 2193 |
|
| 2194 |
try:
|
| 2195 |
logger.info("[Caption] %s: analyzing (mode=%s)...", af.name, analysis_mode)
|
|
@@ -2204,10 +2232,9 @@ def preprocess_audio(
|
|
| 2204 |
logger.warning("[Caption] %s: analysis failed (%s), using filename", af.name, exc)
|
| 2205 |
caption = af.stem
|
| 2206 |
lyrics = "[Instrumental]"
|
| 2207 |
-
text_prompt = caption
|
| 2208 |
|
| 2209 |
with torch.no_grad():
|
| 2210 |
-
text_hs, text_mask = encode_text(text_enc, tokenizer, text_prompt, device, dtype)
|
| 2211 |
lyric_hs, lyric_mask = encode_lyrics(text_enc, tokenizer, lyrics, device, dtype)
|
| 2212 |
|
| 2213 |
has_bad = any(
|
|
@@ -2216,32 +2243,49 @@ def preprocess_audio(
|
|
| 2216 |
)
|
| 2217 |
if has_bad:
|
| 2218 |
p1_failed += 1
|
| 2219 |
-
del target_latents
|
| 2220 |
continue
|
| 2221 |
|
| 2222 |
-
|
| 2223 |
-
|
| 2224 |
-
|
| 2225 |
-
|
| 2226 |
-
|
| 2227 |
-
|
| 2228 |
-
|
| 2229 |
-
|
| 2230 |
-
|
| 2231 |
-
|
| 2232 |
-
|
| 2233 |
-
|
| 2234 |
-
|
| 2235 |
-
|
| 2236 |
-
|
| 2237 |
-
|
| 2238 |
-
}
|
| 2239 |
-
|
| 2240 |
-
|
| 2241 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2242 |
|
| 2243 |
if progress_callback:
|
| 2244 |
-
progress_callback(i + 1, total, f"[Pass 1] {af.name}")
|
| 2245 |
|
| 2246 |
except Exception as exc:
|
| 2247 |
p1_failed += 1
|
|
@@ -2329,7 +2373,8 @@ def preprocess_audio(
|
|
| 2329 |
_clear_gpu_cache(device)
|
| 2330 |
|
| 2331 |
failed = p1_failed + p2_failed
|
| 2332 |
-
return {"processed": processed, "failed": failed, "total": total, "output_dir": str(out)}
|
|
|
|
| 2333 |
|
| 2334 |
|
| 2335 |
# ============================================================================
|
|
|
|
| 64 |
MAX_AUDIO_DURATION = 240.0 # seconds, cap per audio file
|
| 65 |
MAX_TRAINING_TIME = 28800 # 8 hours hard timeout
|
| 66 |
TARGET_SR = 48000
|
| 67 |
+
LATENT_HZ = 25 # latent frames per second (48000 / 1920)
|
| 68 |
+
CHUNK_LATENT_MIN = 20 * LATENT_HZ # 500 frames (20s)
|
| 69 |
+
CHUNK_LATENT_TARGET = 30 * LATENT_HZ # 750 frames (30s)
|
| 70 |
+
CHUNK_LATENT_MAX = 40 * LATENT_HZ # 1000 frames (40s)
|
| 71 |
AUDIO_EXTENSIONS = frozenset({".wav", ".mp3", ".flac", ".ogg", ".opus", ".m4a", ".aac"})
|
| 72 |
|
| 73 |
# bfloat16 deadlocks on CPU (known PyTorch bug) -- force float32
|
|
|
|
| 782 |
return hs, mask
|
| 783 |
|
| 784 |
|
| 785 |
+
# ============================================================================
|
| 786 |
+
# LATENT CHUNKING (split long latents into ~30s training samples)
|
| 787 |
+
# ============================================================================
|
| 788 |
+
|
| 789 |
+
def _chunk_latents(latent: torch.Tensor) -> List[torch.Tensor]:
|
| 790 |
+
"""Split a [T, C] latent into ~30s chunks for faster training.
|
| 791 |
+
|
| 792 |
+
Uses energy-based boundary detection: finds the lowest-energy frame
|
| 793 |
+
within the 20-40s window around each cut point, avoiding cuts through
|
| 794 |
+
loud notes. Short files (<=40s) are returned as-is.
|
| 795 |
+
"""
|
| 796 |
+
T = latent.shape[0]
|
| 797 |
+
if T <= CHUNK_LATENT_MAX:
|
| 798 |
+
return [latent]
|
| 799 |
+
|
| 800 |
+
energy = latent.pow(2).mean(dim=-1) # [T] per-frame energy
|
| 801 |
+
|
| 802 |
+
chunks = []
|
| 803 |
+
pos = 0
|
| 804 |
+
while pos < T:
|
| 805 |
+
remaining = T - pos
|
| 806 |
+
if remaining <= CHUNK_LATENT_MAX:
|
| 807 |
+
if chunks and remaining < CHUNK_LATENT_MIN:
|
| 808 |
+
# Merge short tail into the previous chunk
|
| 809 |
+
chunks[-1] = latent[pos - chunks[-1].shape[0]:]
|
| 810 |
+
else:
|
| 811 |
+
chunks.append(latent[pos:])
|
| 812 |
+
break
|
| 813 |
+
|
| 814 |
+
search_start = pos + CHUNK_LATENT_MIN
|
| 815 |
+
search_end = min(pos + CHUNK_LATENT_MAX, T)
|
| 816 |
+
window = energy[search_start:search_end]
|
| 817 |
+
cut = search_start + window.argmin().item()
|
| 818 |
+
|
| 819 |
+
chunks.append(latent[pos:cut])
|
| 820 |
+
pos = cut
|
| 821 |
+
|
| 822 |
+
return chunks
|
| 823 |
+
|
| 824 |
+
|
| 825 |
# ============================================================================
|
| 826 |
# VAE TILED ENCODING
|
| 827 |
# ============================================================================
|
|
|
|
| 2192 |
del target_latents
|
| 2193 |
continue
|
| 2194 |
|
| 2195 |
+
# Auto-caption (once per file, shared across chunks)
|
|
|
|
|
|
|
|
|
|
| 2196 |
sidecar = _read_caption_sidecar(af)
|
| 2197 |
if sidecar is not None:
|
| 2198 |
caption = sidecar.get("caption", "") or af.stem
|
| 2199 |
lyrics = sidecar.get("lyrics", "[Instrumental]")
|
| 2200 |
logger.info("[Caption] %s: using existing sidecar", af.name)
|
| 2201 |
else:
|
|
|
|
|
|
|
| 2202 |
if device == "cpu":
|
| 2203 |
analysis_mode = "faf"
|
| 2204 |
elif total <= 20:
|
|
|
|
| 2208 |
else:
|
| 2209 |
analysis_mode = "faf"
|
| 2210 |
|
|
|
|
| 2211 |
if i == 0:
|
| 2212 |
_MODE_DESC = {
|
| 2213 |
"faf": "fast, ~3s/file",
|
|
|
|
| 2215 |
"sas": "best quality, ~30s/file on GPU, slower on CPU",
|
| 2216 |
}
|
| 2217 |
logger.info(
|
| 2218 |
+
"[Analysis] Mode '%s' (%s) for %d files",
|
|
|
|
| 2219 |
analysis_mode, _MODE_DESC[analysis_mode], total,
|
| 2220 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2221 |
|
| 2222 |
try:
|
| 2223 |
logger.info("[Caption] %s: analyzing (mode=%s)...", af.name, analysis_mode)
|
|
|
|
| 2232 |
logger.warning("[Caption] %s: analysis failed (%s), using filename", af.name, exc)
|
| 2233 |
caption = af.stem
|
| 2234 |
lyrics = "[Instrumental]"
|
|
|
|
| 2235 |
|
| 2236 |
with torch.no_grad():
|
| 2237 |
+
text_hs, text_mask = encode_text(text_enc, tokenizer, caption, device, dtype)
|
| 2238 |
lyric_hs, lyric_mask = encode_lyrics(text_enc, tokenizer, lyrics, device, dtype)
|
| 2239 |
|
| 2240 |
has_bad = any(
|
|
|
|
| 2243 |
)
|
| 2244 |
if has_bad:
|
| 2245 |
p1_failed += 1
|
| 2246 |
+
del target_latents, text_hs, text_mask, lyric_hs, lyric_mask
|
| 2247 |
continue
|
| 2248 |
|
| 2249 |
+
# Chunk latents into ~30s segments for faster training
|
| 2250 |
+
full_lat = target_latents.squeeze(0).cpu() # [T, C]
|
| 2251 |
+
T = full_lat.shape[0]
|
| 2252 |
+
chunks = _chunk_latents(full_lat)
|
| 2253 |
+
logger.info("[Chunk] %s: %d frames -> %d chunks", af.name, T, len(chunks))
|
| 2254 |
+
|
| 2255 |
+
text_hs_cpu = text_hs.cpu()
|
| 2256 |
+
text_mask_cpu = text_mask.cpu()
|
| 2257 |
+
lyric_hs_cpu = lyric_hs.cpu()
|
| 2258 |
+
lyric_mask_cpu = lyric_mask.cpu()
|
| 2259 |
+
silence_cpu = silence_lat.cpu()
|
| 2260 |
+
meta = {
|
| 2261 |
+
"audio_path": str(af),
|
| 2262 |
+
"filename": af.name,
|
| 2263 |
+
"caption": caption,
|
| 2264 |
+
"lyrics": lyrics,
|
| 2265 |
+
}
|
| 2266 |
+
|
| 2267 |
+
for ci, chunk_lat in enumerate(chunks):
|
| 2268 |
+
chunk_len = chunk_lat.shape[0]
|
| 2269 |
+
chunk_mask = torch.ones(chunk_len, dtype=dtype)
|
| 2270 |
+
tag = f"{stem}_chunk{ci}" if len(chunks) > 1 else stem
|
| 2271 |
+
tmp_path = out / f"{tag}.tmp.pt"
|
| 2272 |
+
torch.save({
|
| 2273 |
+
"target_latents": chunk_lat,
|
| 2274 |
+
"attention_mask": chunk_mask,
|
| 2275 |
+
"text_hidden_states": text_hs_cpu,
|
| 2276 |
+
"text_attention_mask": text_mask_cpu,
|
| 2277 |
+
"lyric_hidden_states": lyric_hs_cpu,
|
| 2278 |
+
"lyric_attention_mask": lyric_mask_cpu,
|
| 2279 |
+
"silence_latent": silence_cpu,
|
| 2280 |
+
"latent_length": chunk_len,
|
| 2281 |
+
"metadata": meta,
|
| 2282 |
+
}, tmp_path)
|
| 2283 |
+
intermediates.append(tmp_path)
|
| 2284 |
+
|
| 2285 |
+
del target_latents, full_lat, text_hs, text_mask, lyric_hs, lyric_mask
|
| 2286 |
|
| 2287 |
if progress_callback:
|
| 2288 |
+
progress_callback(i + 1, total, f"[Pass 1] {af.name} ({len(chunks)} chunks)")
|
| 2289 |
|
| 2290 |
except Exception as exc:
|
| 2291 |
p1_failed += 1
|
|
|
|
| 2373 |
_clear_gpu_cache(device)
|
| 2374 |
|
| 2375 |
failed = p1_failed + p2_failed
|
| 2376 |
+
return {"processed": processed, "failed": failed, "total": total,
|
| 2377 |
+
"chunks": len(intermediates), "output_dir": str(out)}
|
| 2378 |
|
| 2379 |
|
| 2380 |
# ============================================================================
|