"""Fast audio captioning: CLAP tags + Silero VAD + faster-whisper lyrics.
Provides mood/genre/instrument tagging via CLAP zero-shot classification,
speech detection via Silero VAD, and lyrics extraction via faster-whisper.
All models run on CPU. Total: ~3-5 min per file.
Usage:
from caption_fast import caption_audio
result = caption_audio("song.mp3")
# {"caption": "Pop, Energetic, Guitar, Melodic, Upbeat",
# "lyrics": "[Verse]\nSome lyrics here...",
# "bpm": 120, "key": "C major", "signature": "4/4",
# "tags": ["Pop", "Energetic", "Guitar", ...]}
"""
from __future__ import annotations
import json
import logging
import os
from typing import Dict, List
logger = logging.getLogger(__name__)
# Tag list for CLAP zero-shot classification (from clap-interrogator)
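# Entries must stay unique: duplicate prompts would split probability mass in the softmax.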
TAGS = [
"Fast", "Slow", "Upbeat", "Downbeat", "Moderate",
"Happy", "Sad", "Energetic", "Relaxed", "Melancholic", "Uplifting",
"Aggressive", "Peaceful", "Romantic", "Dark", "Light", "Mysterious",
"Dreamy", "Somber", "Hopeful", "Gloomy", "Cheerful", "Reflective",
"Nostalgic", "Tense", "Calm",
"Piano", "Guitar", "Violin", "Drums", "Bass", "Synthesizer",
"Saxophone", "Trumpet", "Flute", "Cello", "Clarinet", "Harp",
"Percussion", "Organ", "Accordion", "Electronic", "Acoustic",
"Electric Guitar", "Acoustic Guitar", "Synth Pad", "Keyboards",
"Rock", "Pop", "Jazz", "Classical", "Electronic", "Folk", "Hip-Hop",
"Blues", "Ambient", "Country", "Reggae", "Funk", "Soul", "Metal",
"Dance", "Disco", "House", "Techno", "Trance", "Soundtrack", "World",
"Indie", "Alternative", "R&B", "EDM", "Chillwave", "Dubstep",
"Lo-fi Hip-Hop", "Drum and Bass", "Jazz Fusion", "Neo-Soul", "Trap",
"K-Pop", "J-Pop", "Reggaeton", "Punk", "Grunge",
"Bright", "Warm", "Smooth", "Distorted", "Clean", "Lo-fi",
"Layered", "Minimalist", "Cinematic", "Atmospheric", "Ethereal",
"Groovy", "Rhythmic", "Melodic", "Harmonic",
"Live", "Studio", "Instrumental",
]
_clap_model = None
_clap_processor = None
_whisper_model = None
_vad_model = None
def _load_clap():
global _clap_model, _clap_processor
if _clap_model is not None:
return _clap_model, _clap_processor
from transformers import ClapModel, ClapProcessor
logger.info("[CLAP] Loading laion/larger_clap_music...")
_clap_processor = ClapProcessor.from_pretrained("laion/larger_clap_music")
_clap_model = ClapModel.from_pretrained("laion/larger_clap_music")
_clap_model.eval()
logger.info("[CLAP] Ready (~780MB)")
return _clap_model, _clap_processor
def _load_whisper():
global _whisper_model
if _whisper_model is not None:
return _whisper_model
from faster_whisper import WhisperModel
logger.info("[Whisper] Loading large-v3-turbo (int8, CPU)...")
    _whisper_model = WhisperModel(
        "large-v3-turbo",
        device="cpu",
        compute_type="int8",  # int8 quantization keeps CPU inference tractable
    )
logger.info("[Whisper] Ready (~1.5GB)")
return _whisper_model
def _load_vad():
global _vad_model
if _vad_model is not None:
return _vad_model
import torch
logger.info("[VAD] Loading Silero VAD...")
    # onnx=True loads the lightweight ONNX export (requires the onnxruntime package)
    _vad_model, _ = torch.hub.load(
        repo_or_dir='snakers4/silero-vad',
        model='silero_vad',
        onnx=True,
        trust_repo=True,
    )
logger.info("[VAD] Ready (~2MB)")
return _vad_model
def unload_caption_models():
"""Free all captioning models from memory."""
global _clap_model, _clap_processor, _whisper_model, _vad_model
import gc
_clap_model = None
_clap_processor = None
_whisper_model = None
_vad_model = None
gc.collect()
logger.info("[Caption] All models unloaded")
def tag_audio(audio_path: str, top_n: int = 10) -> List[str]:
"""Get top-N CLAP tags for an audio file."""
import librosa
import torch
model, processor = _load_clap()
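    # larger_clap_music expects 48 kHz input; librosa resamples on load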
    audio, _ = librosa.load(audio_path, sr=48000, mono=True)
inputs = processor(
text=TAGS,
audio=[audio],
sampling_rate=48000,
return_tensors="pt",
padding=True,
)
with torch.no_grad():
outputs = model(**inputs)
probs = outputs.logits_per_audio.softmax(dim=-1)
    _, top_indices = probs.topk(top_n, dim=1)
    return [TAGS[i] for i in top_indices[0].tolist()]
def detect_speech(audio_path: str, threshold: float = 5.0) -> bool:
"""Check if audio contains speech using Silero VAD.
    Returns True if speech is detected for more than `threshold` seconds.
"""
import torch
import librosa
    vad = _load_vad()
    y, _ = librosa.load(audio_path, sr=16000, mono=True)
    wav = torch.from_numpy(y).unsqueeze(0)
    vad.reset_states()  # the Silero model is stateful; reset it between files
    window_size = 512  # Silero VAD expects 512-sample windows at 16 kHz
    speech_windows = 0
    for i in range(0, wav.shape[1], window_size):
        chunk = wav[0, i:i + window_size]
        if len(chunk) < window_size:
            break  # drop the trailing partial window
        if vad(chunk, 16000).item() > 0.5:
            speech_windows += 1
    speech_duration = speech_windows * (window_size / 16000)
logger.info("[VAD] Speech: %.1fs detected in %s", speech_duration, os.path.basename(audio_path))
return speech_duration > threshold
def transcribe_lyrics(audio_path: str) -> str:
"""Extract lyrics from audio using faster-whisper."""
model = _load_whisper()
    segments, info = model.transcribe(
        audio_path,
        language=None,  # auto-detect the language
        beam_size=5,
        vad_filter=True,  # let Whisper's built-in VAD skip non-speech spans
    )
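    # Note: no [Verse]/[Chorus] structure detection; transcribed lines are joined verbatim.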
lines = []
for segment in segments:
text = segment.text.strip()
if text:
lines.append(text)
lyrics = "\n".join(lines)
if not lyrics.strip():
return "[Instrumental]"
logger.info("[Whisper] Transcribed %d lines (lang=%s, prob=%.2f)",
len(lines), info.language, info.language_probability)
return lyrics
def get_bpm_key(audio_path: str) -> Dict[str, str]:
"""Get BPM and key via librosa."""
import librosa
import numpy as np
y, sr = librosa.load(audio_path, sr=None, mono=True)
tempo, _ = librosa.beat.beat_track(y=y, sr=sr)
bpm = int(round(float(tempo.item() if hasattr(tempo, 'item') else tempo)))
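    # Key estimation via the Krumhansl-Schmuckler approach: correlate the
    # time-averaged chroma vector against all 24 rotated major/minor key profiles.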
chroma = librosa.feature.chroma_cens(y=y, sr=sr)
chroma_avg = np.mean(chroma, axis=1)
keys = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
major_profile = np.array([6.35, 2.23, 3.48, 2.33, 4.38, 4.09, 2.52, 5.19, 2.39, 3.66, 2.29, 2.88])
minor_profile = np.array([6.33, 2.68, 3.52, 5.38, 2.60, 3.53, 2.54, 4.75, 3.98, 2.69, 3.34, 3.17])
best_corr = -1
best_key = "C major"
for i in range(12):
maj_corr = float(np.corrcoef(np.roll(major_profile, i), chroma_avg)[0, 1])
min_corr = float(np.corrcoef(np.roll(minor_profile, i), chroma_avg)[0, 1])
if maj_corr > best_corr:
best_corr = maj_corr
best_key = f"{keys[i]} major"
if min_corr > best_corr:
best_corr = min_corr
best_key = f"{keys[i]} minor"
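    # Time-signature detection is out of scope; report the most common case.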
return {"bpm": str(bpm), "key": best_key, "signature": "4/4"}
def caption_audio(
audio_path: str,
top_n: int = 10,
extract_lyrics: bool = True,
speech_threshold: float = 5.0,
) -> Dict[str, str | List[str]]:
"""Full fast captioning pipeline for one audio file.
Returns dict with: caption, lyrics, bpm, key, signature, tags
"""
fname = os.path.basename(audio_path)
logger.info("[Caption] Processing %s...", fname)
# 1. CLAP tags (mood, genre, instruments)
tags = tag_audio(audio_path, top_n=top_n)
caption = ", ".join(tags)
logger.info("[Caption] %s: tags=%s", fname, caption)
# 2. BPM + key via librosa
bpm_key = get_bpm_key(audio_path)
logger.info("[Caption] %s: BPM=%s, key=%s", fname, bpm_key["bpm"], bpm_key["key"])
# 3. Speech detection + lyrics
lyrics = "[Instrumental]"
if extract_lyrics:
has_speech = detect_speech(audio_path, threshold=speech_threshold)
if has_speech:
logger.info("[Caption] %s: speech detected, transcribing lyrics...", fname)
lyrics = transcribe_lyrics(audio_path)
else:
logger.info("[Caption] %s: no speech, marking instrumental", fname)
return {
"caption": caption,
"lyrics": lyrics,
"bpm": bpm_key["bpm"],
"key": bpm_key["key"],
"signature": bpm_key["signature"],
"tags": tags,
}
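

if __name__ == "__main__":
    # Minimal CLI sketch for ad-hoc testing (an illustrative addition, not part
    # of the pipeline): caption one file passed on the command line and print
    # the result as JSON. Argument names here are arbitrary.
    import argparse

    parser = argparse.ArgumentParser(description="Fast CPU audio captioning")
    parser.add_argument("audio", help="path to an audio file (mp3/wav/flac/...)")
    parser.add_argument("--top-n", type=int, default=10, help="number of CLAP tags")
    parser.add_argument("--no-lyrics", action="store_true", help="skip lyrics extraction")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)
    result = caption_audio(args.audio, top_n=args.top_n, extract_lyrics=not args.no_lyrics)
    print(json.dumps(result, indent=2, ensure_ascii=False))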