SyncAI / src /stem_separator.py
ICGenAIShare04's picture
Upload 52 files
72f552e verified
"""LALAL.AI API wrapper for audio stem separation."""
import os
import shutil
import time
from pathlib import Path
from typing import Optional
import requests
# Root URL shared by every LALAL.AI endpoint called from this module.
API_BASE: str = "https://www.lalal.ai/api/v1"

# "data" folder located next to the directory that contains this file.
DATA_DIR: Path = Path(__file__).parent.parent / "data"

# API stem identifiers requested for every song.
STEMS_TO_EXTRACT: list[str] = ["vocals", "drum"]

# Output file name used for each API stem identifier.
LABEL_TO_FILENAME: dict[str, str] = {
    "vocals": "vocals.wav",
    "drum": "drums.wav",
}
def _get_api_key() -> str:
    """Return the LALAL.AI license key stored in the ``LALAL_KEY`` env variable.

    Raises:
        RuntimeError: If the variable is unset or empty.
    """
    license_key = os.environ.get("LALAL_KEY")
    if license_key:
        return license_key
    raise RuntimeError(
        "LALAL_KEY environment variable not set. "
        "Set it locally or as a HuggingFace Space secret."
    )
def _headers(api_key: str) -> dict:
    """Build the authentication header sent with every LALAL.AI request."""
    auth = {}
    auth["X-License-Key"] = api_key
    return auth
def _next_run_dir(song_dir: Path) -> Path:
    """Return the path of the next unused ``run_NNN`` entry under *song_dir*.

    Every entry matching ``run_*`` is inspected; the text between the first
    and second underscore is parsed as an integer. Entries that do not parse
    are ignored. The result is one past the highest number found (at least 1),
    zero-padded to three digits. Nothing is created on disk.
    """
    highest = 0
    for entry in song_dir.glob("run_*"):
        pieces = entry.name.split("_")
        try:
            highest = max(highest, int(pieces[1]))
        except (IndexError, ValueError):
            # Not a numbered run directory (e.g. "run_abc"); skip it.
            continue
    return song_dir / f"run_{highest + 1:03d}"
def _upload(
    audio_path: Path,
    api_key: str,
    timeout: float | tuple[float, float] = (10.0, 300.0),
) -> str:
    """Upload an audio file to LALAL.AI and return its source id.

    The file is streamed as the raw request body and its name is sent in the
    ``Content-Disposition`` header.

    Args:
        audio_path: Local audio file to upload.
        api_key: LALAL.AI license key.
        timeout: ``requests`` timeout in seconds, either a single value or a
            ``(connect, read)`` pair. Without one a stalled connection would
            block forever.

    Returns:
        The ``id`` field of the JSON response.

    Raises:
        requests.RequestException: On network failure, timeout or an HTTP
            error status.
        KeyError: If the response JSON has no ``id`` field.
    """
    from urllib.parse import quote

    # Header values are sent latin-1 encoded; a name outside that range would
    # raise UnicodeEncodeError before the request leaves. Percent-encode such
    # names instead. NOTE(review): the service then sees the encoded name —
    # confirm that is acceptable.
    header_name = audio_path.name
    try:
        header_name.encode("latin-1")
    except UnicodeEncodeError:
        header_name = quote(header_name)

    with open(audio_path, "rb") as f:
        resp = requests.post(
            f"{API_BASE}/upload/",
            headers={
                **_headers(api_key),
                "Content-Disposition": f'attachment; filename="{header_name}"',
            },
            data=f,
            timeout=timeout,
        )
    resp.raise_for_status()
    data = resp.json()
    source_id = data["id"]

    # The duration is only used for logging, so its absence must not turn a
    # successful upload into a failure.
    duration = data.get("duration")
    duration_note = (
        f" (duration: {duration:.1f}s)"
        if isinstance(duration, (int, float))
        else ""
    )
    print(f" Uploaded {audio_path.name} → source_id={source_id}{duration_note}")
    return source_id
def _split_stem(
    source_id: str,
    stem: str,
    api_key: str,
    timeout: float | tuple[float, float] = 30.0,
) -> str:
    """Start a stem separation task and return its task id.

    Args:
        source_id: Id of a previously uploaded source file.
        stem: API stem identifier to extract (e.g. ``"vocals"``, ``"drum"``).
        api_key: LALAL.AI license key.
        timeout: ``requests`` timeout in seconds, either a single value or a
            ``(connect, read)`` pair, so a stalled connection cannot hang the
            pipeline.

    Returns:
        The ``task_id`` field of the JSON response.

    Raises:
        requests.RequestException: On network failure, timeout or an HTTP
            error status.
        KeyError: If the response JSON has no ``task_id`` field.
    """
    # Andromeda is requested for vocals only; for every other stem the field
    # is sent as JSON null (presumably letting the service pick the splitter).
    splitter = "andromeda" if stem == "vocals" else None
    resp = requests.post(
        f"{API_BASE}/split/stem_separator/",
        headers=_headers(api_key),
        json={
            "source_id": source_id,
            "presets": {
                "stem": stem,
                "splitter": splitter,
                "dereverb_enabled": False,
                "encoder_format": "wav",
                "extraction_level": "deep_extraction",
            },
        },
        timeout=timeout,
    )
    resp.raise_for_status()
    data = resp.json()
    task_id = data["task_id"]
    print(f" Split task started: stem={stem}, task_id={task_id}")
    return task_id
def _poll_tasks(task_ids: list[str], api_key: str, poll_interval: float = 5.0) -> dict:
"""Poll tasks until all complete. Returns {task_id: result_data}."""
pending = set(task_ids)
results = {}
while pending:
resp = requests.post(
f"{API_BASE}/check/",
headers=_headers(api_key),
json={"task_ids": list(pending)},
)
resp.raise_for_status()
data = resp.json().get("result", resp.json())
for task_id, info in data.items():
status = info.get("status")
if status == "success":
results[task_id] = info
pending.discard(task_id)
print(f" Task {task_id}: complete")
elif status == "progress":
print(f" Task {task_id}: {info.get('progress', 0)}%")
elif status == "error":
error = info.get("error", {})
raise RuntimeError(
f"LALAL.AI task {task_id} failed: "
f"{error.get('detail', 'unknown error')} "
f"(code: {error.get('code')})"
)
elif status == "cancelled":
raise RuntimeError(f"LALAL.AI task {task_id} was cancelled")
elif status == "server_error":
raise RuntimeError(
f"LALAL.AI server error for task {task_id}: "
f"{info.get('error', 'unknown')}"
)
if pending:
time.sleep(poll_interval)
return results
def _download_track(
    url: str,
    output_path: Path,
    timeout: float | tuple[float, float] = (10.0, 60.0),
) -> None:
    """Stream a track from *url* into *output_path*.

    Args:
        url: Download URL of the track.
        output_path: Destination file; overwritten if it exists.
        timeout: ``requests`` timeout in seconds, either a single value or a
            ``(connect, read)`` pair, so a stalled transfer cannot hang the
            pipeline.

    Raises:
        requests.RequestException: On network failure, timeout or an HTTP
            error status.
        OSError: If the destination file cannot be written.
    """
    # The context manager releases the streamed connection even when the
    # status check or the file write raises.
    with requests.get(url, stream=True, timeout=timeout) as resp:
        resp.raise_for_status()
        with open(output_path, "wb") as f:
            for chunk in resp.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)
    print(f" Downloaded → {output_path.name} ({output_path.stat().st_size / 1024:.0f} KB)")
def _delete_source(source_id: str, api_key: str, timeout: float = 30.0) -> None:
    """Ask LALAL.AI to delete an uploaded source file (best effort).

    Cleanup is non-critical, so failures are reported on stdout and never
    raised.

    Args:
        source_id: Id of the uploaded source to delete.
        api_key: LALAL.AI license key.
        timeout: ``requests`` timeout in seconds, so cleanup cannot hang after
            the real work is done.
    """
    try:
        resp = requests.post(
            f"{API_BASE}/delete/",
            headers=_headers(api_key),
            json={"source_id": source_id},
            timeout=timeout,
        )
        # Only report success when the server actually accepted the request.
        resp.raise_for_status()
    except Exception as exc:  # deliberately broad: cleanup must never fail the run
        print(f" Warning: could not delete remote source {source_id}: {exc}")
    else:
        print(f" Cleaned up remote source {source_id}")
def separate_stems(
    audio_path: str | Path,
    output_dir: Optional[str | Path] = None,
) -> dict[str, Path]:
    """Separate an audio file into vocals and drums using LALAL.AI.

    Creates a new run directory for each invocation so multiple runs
    on the same song don't overwrite each other. The original file is copied
    once into ``data/<song>/`` and shared across runs.

    Args:
        audio_path: Path to the input audio file (mp3/wav) from input/.
        output_dir: Directory to save stems. If None, auto-creates
            data/<song>/run_NNN/stems/.

    Returns:
        Dict mapping stem names to their file paths.
        Keys: "drums", "vocals", "run_dir"

    Raises:
        RuntimeError: If the API key is missing, a task fails, or a result
            contains no stem track.
        FileNotFoundError: If ``audio_path`` is not an existing file.
    """
    audio_path = Path(audio_path)
    song_dir = DATA_DIR / audio_path.stem
    api_key = _get_api_key()

    # Fail before creating directories or contacting the service.
    if not audio_path.is_file():
        raise FileNotFoundError(f"Audio file not found: {audio_path}")

    if output_dir is None:
        run_dir = _next_run_dir(song_dir)
        output_dir = run_dir / "stems"
    else:
        output_dir = Path(output_dir)
        run_dir = output_dir.parent
    output_dir.mkdir(parents=True, exist_ok=True)

    # Copy original song into song directory (shared across runs). With an
    # explicit output_dir the song directory may not exist yet.
    song_dir.mkdir(parents=True, exist_ok=True)
    song_copy = song_dir / audio_path.name
    if not song_copy.exists():
        shutil.copy2(audio_path, song_copy)

    # 1. Upload
    print("Stem separation (LALAL.AI):")
    source_id = _upload(audio_path, api_key)

    try:
        # 2. Start split tasks for each stem
        task_to_stem = {}
        for stem in STEMS_TO_EXTRACT:
            task_id = _split_stem(source_id, stem, api_key)
            task_to_stem[task_id] = stem

        # 3. Poll until all tasks complete
        results = _poll_tasks(list(task_to_stem), api_key)

        # 4. Download the separated stem tracks. Iterate over the tasks we
        #    started so a missing or unexpected entry is reported explicitly.
        stem_paths = {"run_dir": run_dir}
        for task_id, stem in task_to_stem.items():
            result_data = results.get(task_id)
            if result_data is None:
                raise RuntimeError(f"No result returned for {stem}")

            tracks = (result_data.get("result") or {}).get("tracks") or []
            # Find the "stem" track (not the "back"/inverse track)
            stem_track = next(
                (t for t in tracks if t.get("type") == "stem"), None
            )
            if stem_track is None:
                raise RuntimeError(f"No stem track found in result for {stem}")

            output_path = output_dir / LABEL_TO_FILENAME[stem]
            _download_track(stem_track["url"], output_path)

            # Map to our naming: "drum" API stem → "drums" key
            key = "drums" if stem == "drum" else stem
            stem_paths[key] = output_path
    finally:
        # 5. Cleanup remote files, also when a step above failed, so uploads
        #    are not left behind on the service.
        _delete_source(source_id, api_key)

    return stem_paths
# Command-line entry point: separate the stems of the file given as argv[1]
# and list where each output was written.
if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        print("Usage: python -m src.stem_separator <audio_file>")
        sys.exit(1)

    outputs = separate_stems(sys.argv[1])
    print(f"Run directory: {outputs['run_dir']}")
    for label in outputs:
        if label == "run_dir":
            continue
        print(f" {label}: {outputs[label]}")