| """Fast audio captioning: CLAP tags + Silero VAD + faster-whisper lyrics. | |
| Provides mood/genre/instrument tagging via CLAP zero-shot classification, | |
| speech detection via Silero VAD, and lyrics extraction via faster-whisper. | |
| All models run on CPU. Total: ~3-5 min per file. | |
| Usage: | |
| from caption_fast import caption_audio | |
| result = caption_audio("song.mp3") | |
| # {"caption": "Pop, Energetic, Guitar, Melodic, Upbeat", | |
| # "lyrics": "[Verse]\nSome lyrics here...", | |
| # "bpm": 120, "key": "C major", "signature": "4/4", | |
| # "tags": ["Pop", "Energetic", "Guitar", ...]} | |
| """ | |
from __future__ import annotations

import json
import logging
import os
from typing import Any, Dict, List

logger = logging.getLogger(__name__)
# Tag list for CLAP zero-shot classification (from clap-interrogator)
TAGS = [
    # Tempo / energy
    "Fast", "Slow", "Upbeat", "Downbeat", "Moderate",
    # Mood
    "Happy", "Sad", "Energetic", "Relaxed", "Melancholic", "Uplifting",
    "Aggressive", "Peaceful", "Romantic", "Dark", "Light", "Mysterious",
    "Dreamy", "Somber", "Hopeful", "Gloomy", "Cheerful", "Reflective",
    "Nostalgic", "Tense", "Calm",
    # Instrumentation
    "Piano", "Guitar", "Violin", "Drums", "Bass", "Synthesizer",
    "Saxophone", "Trumpet", "Flute", "Cello", "Clarinet", "Harp",
    "Percussion", "Organ", "Accordion", "Electronic", "Acoustic",
    "Electric Guitar", "Acoustic Guitar", "Synth Pad", "Keyboards",
    # Genre ("Electronic" appears once above; a duplicate label would split
    # its probability mass in the softmax)
    "Rock", "Pop", "Jazz", "Classical", "Folk", "Hip-Hop",
    "Blues", "Ambient", "Country", "Reggae", "Funk", "Soul", "Metal",
    "Dance", "Disco", "House", "Techno", "Trance", "Soundtrack", "World",
    "Indie", "Alternative", "R&B", "EDM", "Chillwave", "Dubstep",
    "Lo-fi Hip-Hop", "Drum and Bass", "Jazz Fusion", "Neo-Soul", "Trap",
    "K-Pop", "J-Pop", "Reggaeton", "Punk", "Grunge",
    # Texture / production
    "Bright", "Warm", "Smooth", "Distorted", "Clean", "Lo-fi",
    "Layered", "Minimalist", "Cinematic", "Atmospheric", "Ethereal",
    "Groovy", "Rhythmic", "Melodic", "Harmonic",
    "Live", "Studio", "Instrumental",
]
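
# Lazily initialized module-level singletons: each model is loaded at most
# once per process and reused across calls until explicitly unloaded.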
_clap_model = None
_clap_processor = None
_whisper_model = None
_vad_model = None


def _load_clap():
    global _clap_model, _clap_processor
    if _clap_model is not None:
        return _clap_model, _clap_processor
    from transformers import ClapModel, ClapProcessor

    logger.info("[CLAP] Loading laion/larger_clap_music...")
    _clap_processor = ClapProcessor.from_pretrained("laion/larger_clap_music")
    _clap_model = ClapModel.from_pretrained("laion/larger_clap_music")
    _clap_model.eval()
    logger.info("[CLAP] Ready (~780MB)")
    return _clap_model, _clap_processor


def _load_whisper():
    global _whisper_model
    if _whisper_model is not None:
        return _whisper_model
    from faster_whisper import WhisperModel

    logger.info("[Whisper] Loading large-v3-turbo (int8, CPU)...")
    _whisper_model = WhisperModel(
        "large-v3-turbo",
        device="cpu",
        compute_type="int8",
    )
    logger.info("[Whisper] Ready (~1.5GB)")
    return _whisper_model


def _load_vad():
    global _vad_model
    if _vad_model is not None:
        return _vad_model
    import torch

    logger.info("[VAD] Loading Silero VAD...")
    # torch.hub returns (model, utils); the utility helpers are not needed here.
    _vad_model, _ = torch.hub.load(
        repo_or_dir='snakers4/silero-vad',
        model='silero_vad',
        onnx=True,
        trust_repo=True,
    )
    logger.info("[VAD] Ready (~2MB)")
    return _vad_model


def unload_caption_models():
    """Free all captioning models from memory."""
    global _clap_model, _clap_processor, _whisper_model, _vad_model
    import gc

    _clap_model = None
    _clap_processor = None
    _whisper_model = None
    _vad_model = None
    gc.collect()
    logger.info("[Caption] All models unloaded")


def tag_audio(audio_path: str, top_n: int = 10) -> List[str]:
    """Get top-N CLAP tags for an audio file."""
    import librosa
    import torch

    model, processor = _load_clap()
    # CLAP expects 48 kHz mono input.
    audio, _ = librosa.load(audio_path, sr=48000, mono=True)
    inputs = processor(
        text=TAGS,
        audio=[audio],
        sampling_rate=48000,
        return_tensors="pt",
        padding=True,
    )
    with torch.no_grad():
        outputs = model(**inputs)
    # Rank all candidate tags by audio-text similarity and keep the top N.
    probs = outputs.logits_per_audio.softmax(dim=-1)
    _, top_indices = probs.topk(top_n, dim=1)
    return [TAGS[i] for i in top_indices[0].tolist()]
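
# Illustrative call (hypothetical output, mirroring the module docstring example):
#   tag_audio("song.mp3", top_n=5) -> ["Pop", "Energetic", "Guitar", "Melodic", "Upbeat"]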


def detect_speech(audio_path: str, threshold: float = 5.0) -> bool:
    """Check if audio contains speech using Silero VAD.

    Returns True if speech is detected for more than `threshold` seconds.
    """
    import torch
    import librosa

    vad = _load_vad()
    vad.reset_states()  # clear internal state carried over from any previous file
    # Silero VAD expects 16 kHz mono input and scores 512-sample frames (32 ms).
    y, _ = librosa.load(audio_path, sr=16000, mono=True)
    wav = torch.from_numpy(y).unsqueeze(0)
    speech_frames = 0
    window_size = 512
    for i in range(0, wav.shape[1], window_size):
        chunk = wav[0, i:i + window_size]
        if len(chunk) < window_size:
            break
        prob = vad(chunk, 16000).item()
        if prob > 0.5:
            speech_frames += 1
    speech_duration = speech_frames * (window_size / 16000)
    logger.info("[VAD] Speech: %.1fs detected in %s", speech_duration, os.path.basename(audio_path))
    return speech_duration > threshold


def transcribe_lyrics(audio_path: str) -> str:
    """Extract lyrics from audio using faster-whisper."""
    model = _load_whisper()
    segments, info = model.transcribe(
        audio_path,
        language=None,  # auto-detect
        beam_size=5,
        vad_filter=True,  # skip non-speech regions inside the track
    )
    lines = []
    for segment in segments:
        text = segment.text.strip()
        if text:
            lines.append(text)
    lyrics = "\n".join(lines)
    if not lyrics.strip():
        return "[Instrumental]"
    logger.info("[Whisper] Transcribed %d lines (lang=%s, prob=%.2f)",
                len(lines), info.language, info.language_probability)
    return lyrics


def get_bpm_key(audio_path: str) -> Dict[str, str]:
    """Get BPM and key via librosa."""
    import librosa
    import numpy as np

    y, sr = librosa.load(audio_path, sr=None, mono=True)
    tempo, _ = librosa.beat.beat_track(y=y, sr=sr)
    # Newer librosa returns tempo as a 0-d ndarray; older versions return a float.
    bpm = int(round(float(tempo.item() if hasattr(tempo, 'item') else tempo)))
    # Key estimation: correlate the time-averaged chroma against the
    # Krumhansl-Schmuckler major/minor key profiles at all 12 rotations.
    chroma = librosa.feature.chroma_cens(y=y, sr=sr)
    chroma_avg = np.mean(chroma, axis=1)
    keys = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
    major_profile = np.array([6.35, 2.23, 3.48, 2.33, 4.38, 4.09, 2.52, 5.19, 2.39, 3.66, 2.29, 2.88])
    minor_profile = np.array([6.33, 2.68, 3.52, 5.38, 2.60, 3.53, 2.54, 4.75, 3.98, 2.69, 3.34, 3.17])
    best_corr = -np.inf
    best_key = "C major"
    for i in range(12):
        maj_corr = float(np.corrcoef(np.roll(major_profile, i), chroma_avg)[0, 1])
        min_corr = float(np.corrcoef(np.roll(minor_profile, i), chroma_avg)[0, 1])
        if maj_corr > best_corr:
            best_corr = maj_corr
            best_key = f"{keys[i]} major"
        if min_corr > best_corr:
            best_corr = min_corr
            best_key = f"{keys[i]} minor"
    # Time signature is not estimated; 4/4 is a fixed default.
    return {"bpm": str(bpm), "key": best_key, "signature": "4/4"}


def caption_audio(
    audio_path: str,
    top_n: int = 10,
    extract_lyrics: bool = True,
    speech_threshold: float = 5.0,
) -> Dict[str, Any]:
    """Full fast captioning pipeline for one audio file.

    Returns a dict with: caption, lyrics, bpm, key, signature, tags.
    """
    fname = os.path.basename(audio_path)
    logger.info("[Caption] Processing %s...", fname)

    # 1. CLAP tags (mood, genre, instruments)
    tags = tag_audio(audio_path, top_n=top_n)
    caption = ", ".join(tags)
    logger.info("[Caption] %s: tags=%s", fname, caption)

    # 2. BPM + key via librosa
    bpm_key = get_bpm_key(audio_path)
    logger.info("[Caption] %s: BPM=%s, key=%s", fname, bpm_key["bpm"], bpm_key["key"])

    # 3. Speech detection + lyrics
    lyrics = "[Instrumental]"
    if extract_lyrics:
        has_speech = detect_speech(audio_path, threshold=speech_threshold)
        if has_speech:
            logger.info("[Caption] %s: speech detected, transcribing lyrics...", fname)
            lyrics = transcribe_lyrics(audio_path)
        else:
            logger.info("[Caption] %s: no speech, marking instrumental", fname)

    return {
        "caption": caption,
        "lyrics": lyrics,
        "bpm": bpm_key["bpm"],
        "key": bpm_key["key"],
        "signature": bpm_key["signature"],
        "tags": tags,
    }
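

# Minimal CLI sketch for ad-hoc testing; the argument handling below is an
# assumption for illustration, not part of the original module's interface.
if __name__ == "__main__":
    import sys

    if len(sys.argv) != 2:
        print("usage: python caption_fast.py <audio_file>", file=sys.stderr)
        sys.exit(1)
    logging.basicConfig(level=logging.INFO)
    print(json.dumps(caption_audio(sys.argv[1]), indent=2))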