# strandtest / app.py
import os
import tempfile
from pathlib import Path
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import CLIPProcessor, CLIPModel
try:
from transformers import ClapModel, ClapProcessor
CLAP_AVAILABLE = True
CLAP_METHOD = "transformers"
except ImportError as e1:
CLAP_AVAILABLE = False
CLAP_METHOD = None
# Check MERT availability
try:
from transformers import AutoModel, Wav2Vec2FeatureExtractor
MERT_AVAILABLE = True
MERT_METHOD = "transformers"
except ImportError as e2:
MERT_AVAILABLE = False
MERT_METHOD = None
import torch
import torchaudio
from PIL import Image
import requests
import numpy as np
import io
import logging
import librosa
import soundfile as sf
import scipy.signal
# Set environment to disable librosa caching
os.environ['LIBROSA_CACHE_DIR'] = '/tmp'
os.environ['JOBLIB_TEMP_FOLDER'] = '/tmp'
# Disable librosa caching completely to avoid the __o_fold error
os.environ['LIBROSA_CACHE_LEVEL'] = '0'
os.environ['LIBROSA_CACHE_COMPRESS'] = '0'
# Check pitch-aware model availability
try:
    # librosa is already imported unconditionally above, so this probe effectively always
    # succeeds; it is kept so the flag logic mirrors the CLAP/MERT availability checks
    import librosa  # noqa: F811
PITCH_AWARE_AVAILABLE = True
PITCH_METHOD = "librosa_chroma"
except ImportError:
PITCH_AWARE_AVAILABLE = False
PITCH_METHOD = None
# Fusion configuration
FUSION_MODE = os.environ.get('FUSION_MODE', 'VECTOR_CONCAT') # VECTOR_CONCAT or SCORE_FUSION
FUSION_ALPHA = float(os.environ.get('FUSION_ALPHA', '0.6')) # Alpha for score fusion
ENABLE_PITCH_FUSION = os.environ.get('ENABLE_PITCH_FUSION', 'false').lower() == 'true' and PITCH_AWARE_AVAILABLE
ENABLE_MERT_FUSION = os.environ.get('ENABLE_MERT_FUSION', 'false').lower() == 'true' and MERT_AVAILABLE
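# Resulting audio embedding dimensions under VECTOR_CONCAT: 512 (CLAP only), 597 (+85 pitch),
# 1280 (+768 MERT), or 1365 (CLAP + MERT + pitch). Under SCORE_FUSION the /encode/audio
# endpoint returns the per-model embeddings separately along with FUSION_ALPHA so the caller
# can blend similarity scores itself.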
# Audio processing limits
MAX_AUDIO_DURATION_SEC = int(os.environ.get('MAX_AUDIO_DURATION_SEC', '600')) # 10 minutes default
# Set up cache directories
cache_dir = os.environ.get('TRANSFORMERS_CACHE', '/code/cache')
os.makedirs(cache_dir, exist_ok=True)
os.environ['TRANSFORMERS_CACHE'] = cache_dir
os.environ['HF_HOME'] = cache_dir
os.environ['TORCH_HOME'] = cache_dir
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(title="CLIP Service", version="1.0.0")
def sanitize_for_json(embedding):
"""Sanitize embedding for JSON serialization"""
if isinstance(embedding, np.ndarray):
# Ensure finite values and convert to list
embedding = np.nan_to_num(embedding, nan=0.0, posinf=0.0, neginf=0.0)
return embedding.tolist()
elif isinstance(embedding, list):
# Check each element for finite values
return [float(x) if np.isfinite(x) else 0.0 for x in embedding]
else:
return embedding
# Log CLAP, MERT, and pitch-aware availability after logger is initialized
logger.info(f"CLAP availability: {CLAP_AVAILABLE}, method: {CLAP_METHOD}")
logger.info(f"MERT availability: {MERT_AVAILABLE}, method: {MERT_METHOD}")
logger.info(f"Pitch-aware availability: {PITCH_AWARE_AVAILABLE}, method: {PITCH_METHOD}")
logger.info(f"Pitch fusion enabled: {ENABLE_PITCH_FUSION}")
logger.info(f"MERT fusion enabled: {ENABLE_MERT_FUSION}")
logger.info(f"Fusion mode: {FUSION_MODE}")
if FUSION_MODE == 'SCORE_FUSION':
logger.info(f"Fusion alpha: {FUSION_ALPHA}")
class CLIPService:
def __init__(self):
logger.info("Loading CLIP model...")
self.clap_model = None
self.clap_processor = None
self.mert_model = None
self.mert_processor = None
# Simple in-memory cache for pitch features keyed by audio hash
# Using a dict avoids the "unhashable type: numpy.ndarray" error we hit with functools.lru_cache
self.pitch_feature_cache: dict[int, np.ndarray] = {}
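        # Note: this cache is unbounded and keyed on a hash of the raw window bytes, so
        # identical windows (e.g. digital silence) are only computed once; long-running
        # processes may want to clear it periodically to bound memory use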
try:
# Use CPU for Hugging Face free tier
self.device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {self.device}")
# Load CLIP model with explicit cache directory
logger.info("Loading CLIP model from HuggingFace...")
self.clip_model = CLIPModel.from_pretrained(
"openai/clip-vit-large-patch14",
cache_dir=cache_dir,
local_files_only=False
).to(self.device)
logger.info("Loading CLIP processor...")
use_fast = os.environ.get('USE_FAST_PROCESSOR', 'true').lower() == 'true'
logger.info(f"Using fast processor: {use_fast}")
self.clip_processor = CLIPProcessor.from_pretrained(
"openai/clip-vit-large-patch14",
cache_dir=cache_dir,
local_files_only=False,
use_fast=use_fast
)
logger.info(f"CLIP model loaded successfully on {self.device}")
except Exception as e:
logger.error(f"Failed to load CLIP model: {str(e)}")
logger.error(f"Error type: {type(e).__name__}")
raise RuntimeError(f"CLIP model loading failed: {str(e)}")
def _load_clap_model(self):
"""Load CLAP model on demand"""
if not CLAP_AVAILABLE:
raise RuntimeError("CLAP model not available - transformers version may not support CLAP")
if self.clap_model is None:
logger.info(f"Loading CLAP model on demand using {CLAP_METHOD} method...")
try:
if CLAP_METHOD == "transformers":
logger.info("Loading CLAP model from HuggingFace...")
self.clap_model = ClapModel.from_pretrained(
"laion/clap-htsat-unfused",
cache_dir=cache_dir,
local_files_only=False
).to(self.device)
logger.info("Loading CLAP processor...")
use_fast = os.environ.get('USE_FAST_PROCESSOR', 'true').lower() == 'true'
self.clap_processor = ClapProcessor.from_pretrained(
"laion/clap-htsat-unfused",
cache_dir=cache_dir,
local_files_only=False,
use_fast=use_fast
)
logger.info(f"CLAP model loaded successfully on {self.device} using {CLAP_METHOD}")
except Exception as e:
logger.error(f"Failed to load CLAP model: {str(e)}")
logger.error(f"Error type: {type(e).__name__}")
raise RuntimeError(f"CLAP model loading failed: {str(e)}")
def _load_mert_model(self):
"""Load MERT model on demand"""
if not MERT_AVAILABLE:
raise RuntimeError("MERT model not available - transformers version may not support MERT")
if self.mert_model is None:
logger.info(f"Loading MERT model on demand using {MERT_METHOD} method...")
try:
logger.info("Loading MERT model from HuggingFace...")
self.mert_model = AutoModel.from_pretrained(
"m-a-p/MERT-v1-95M",
trust_remote_code=True,
cache_dir=cache_dir,
local_files_only=False
).to(self.device)
# Guard against missing encoder (stale HF cache issue)
if not hasattr(self.mert_model, "encoder"):
raise RuntimeError("MERT weights not loaded - clear HF cache and retry")
logger.info("Loading MERT processor...")
use_fast = os.environ.get('USE_FAST_PROCESSOR', 'true').lower() == 'true'
self.mert_processor = Wav2Vec2FeatureExtractor.from_pretrained(
"m-a-p/MERT-v1-95M",
cache_dir=cache_dir,
local_files_only=False,
use_fast=use_fast
)
logger.info(f"MERT model loaded successfully on {self.device} using {MERT_METHOD}")
except Exception as e:
logger.error(f"Failed to load MERT model: {str(e)}")
logger.error(f"Error type: {type(e).__name__}")
raise RuntimeError(f"MERT model loading failed: {str(e)}")
def extract_pitch_features(self, audio_array: np.ndarray, sample_rate: int) -> np.ndarray:
"""Extract pitch-aware features using numpy/scipy (avoiding all librosa caching issues) with a lightweight dict cache"""
cache_key = hash(audio_array.tobytes())
# Fast path: return cached result if we've already seen this audio chunk
if cache_key in self.pitch_feature_cache:
return self.pitch_feature_cache[cache_key]
# Slow path: compute fresh features and store them
features = self._extract_pitch_features_impl(audio_array, sample_rate)
self.pitch_feature_cache[cache_key] = features
return features
def _extract_pitch_features_impl(self, audio_array: np.ndarray, sample_rate: int) -> np.ndarray:
"""Implementation of pitch feature extraction (cached)"""
try:
# Use pure numpy/scipy implementations to avoid all librosa caching issues
features = []
# Extract basic audio features using numpy/scipy only
try:
# Basic spectral features using numpy FFT
spectral_features = self._extract_spectral_features_numpy(audio_array, sample_rate)
features.extend(spectral_features)
logger.info("✓ Spectral features extracted (numpy)")
except Exception as e:
logger.warning(f"Spectral feature extraction failed: {e}, using zeros")
features.extend(np.zeros(20)) # 20 spectral features
try:
# Basic temporal features
temporal_features = self._extract_temporal_features_numpy(audio_array, sample_rate)
features.extend(temporal_features)
logger.info("✓ Temporal features extracted (numpy)")
except Exception as e:
logger.warning(f"Temporal feature extraction failed: {e}, using zeros")
features.extend(np.zeros(15)) # 15 temporal features
try:
# Basic frequency domain features
frequency_features = self._extract_frequency_features_numpy(audio_array, sample_rate)
features.extend(frequency_features)
logger.info("✓ Frequency features extracted (numpy)")
except Exception as e:
logger.warning(f"Frequency feature extraction failed: {e}, using zeros")
features.extend(np.zeros(25)) # 25 frequency features
# Simple tempo estimation (fallback)
try:
tempo = self._estimate_tempo_numpy(audio_array, sample_rate)
features.append(tempo)
logger.info(f"✓ Tempo estimated: {tempo:.1f} BPM")
except Exception as e:
logger.warning(f"Tempo estimation failed: {e}, using default")
features.append(120.0) # Default BPM
# Convert to numpy array and check for NaN/inf values
pitch_features = np.array(features, dtype=np.float32)
# Replace any NaN or inf values with 0
pitch_features = np.nan_to_num(pitch_features, nan=0.0, posinf=0.0, neginf=0.0)
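            # Layout: 20 spectral + 15 temporal + 25 frequency-domain + 1 tempo = 61 raw values,
            # zero-padded below to a fixed 85-dim vector and then L2-normalized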
# Ensure we have the expected 85 dimensions
if len(pitch_features) < 85:
# Pad with zeros if needed
padding = np.zeros(85 - len(pitch_features), dtype=np.float32)
pitch_features = np.concatenate([pitch_features, padding])
elif len(pitch_features) > 85:
# Truncate if too long
pitch_features = pitch_features[:85]
# L2 normalize
norm = np.linalg.norm(pitch_features)
if norm > 0:
pitch_features = pitch_features / norm
else:
# If norm is 0, create a small non-zero vector
pitch_features = np.ones(85, dtype=np.float32) * 0.001
pitch_features = pitch_features / np.linalg.norm(pitch_features)
# Final check for finite values
pitch_features = np.nan_to_num(pitch_features, nan=0.0, posinf=0.0, neginf=0.0)
            # Result is cached by extract_pitch_features via the dict cache (not an LRU decorator)
logger.info(f"Extracted pitch features: {len(pitch_features)} dimensions")
return pitch_features
except Exception as e:
logger.error(f"Error extracting pitch features: {str(e)}")
# Return normalized zero vector if extraction fails
zero_features = np.ones(85, dtype=np.float32) * 0.001
return zero_features / np.linalg.norm(zero_features)
def _extract_spectral_features_numpy(self, audio_array: np.ndarray, sample_rate: int) -> list:
"""Extract spectral features using only numpy (no librosa)"""
# Compute FFT
fft = np.fft.fft(audio_array)
magnitude = np.abs(fft)
freqs = np.fft.fftfreq(len(audio_array), 1/sample_rate)
# Only use positive frequencies
pos_freqs = freqs[:len(freqs)//2]
pos_magnitude = magnitude[:len(magnitude)//2]
# Spectral centroid
spectral_centroid = np.sum(pos_freqs * pos_magnitude) / np.sum(pos_magnitude)
# Spectral rolloff (95% of energy)
cumsum = np.cumsum(pos_magnitude)
rolloff_idx = np.where(cumsum >= 0.95 * cumsum[-1])[0]
spectral_rolloff = pos_freqs[rolloff_idx[0]] if len(rolloff_idx) > 0 else 0
# Spectral spread
spectral_spread = np.sqrt(np.sum(((pos_freqs - spectral_centroid) ** 2) * pos_magnitude) / np.sum(pos_magnitude))
# Zero crossing rate
zero_crossings = np.where(np.diff(np.sign(audio_array)))[0]
zcr = len(zero_crossings) / len(audio_array)
# RMS energy
rms = np.sqrt(np.mean(audio_array ** 2))
# Basic spectral features
features = [
spectral_centroid / sample_rate, # Normalize
spectral_rolloff / sample_rate, # Normalize
spectral_spread / sample_rate, # Normalize
zcr,
rms,
np.max(pos_magnitude),
np.mean(pos_magnitude),
np.std(pos_magnitude),
np.sum(pos_magnitude),
np.var(pos_magnitude)
]
# Add frequency band energies (10 bands)
band_energies = []
n_bands = 10
for i in range(n_bands):
start_idx = i * len(pos_magnitude) // n_bands
end_idx = (i + 1) * len(pos_magnitude) // n_bands
band_energy = np.sum(pos_magnitude[start_idx:end_idx])
band_energies.append(band_energy)
features.extend(band_energies)
return features
def _extract_temporal_features_numpy(self, audio_array: np.ndarray, sample_rate: int) -> list:
"""Extract temporal features using only numpy"""
# Basic statistics
mean_val = np.mean(audio_array)
std_val = np.std(audio_array)
skew_val = np.mean(((audio_array - mean_val) / std_val) ** 3)
kurtosis_val = np.mean(((audio_array - mean_val) / std_val) ** 4)
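        # Note: std_val is 0 for constant/silent audio, which makes skew/kurtosis NaN here;
        # those NaNs are scrubbed by np.nan_to_num in the calling feature pipeline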
# Energy-based features
energy = np.sum(audio_array ** 2)
power = energy / len(audio_array)
# Envelope features
envelope = np.abs(audio_array)
envelope_mean = np.mean(envelope)
envelope_std = np.std(envelope)
# Attack/decay characteristics
envelope_diff = np.diff(envelope)
attack_time = np.mean(envelope_diff[envelope_diff > 0])
decay_time = np.mean(-envelope_diff[envelope_diff < 0])
# Peak characteristics
peaks = np.where(np.diff(np.sign(np.diff(audio_array))) < 0)[0]
peak_density = len(peaks) / len(audio_array)
# Dynamics
dynamic_range = np.max(envelope) - np.min(envelope)
features = [
mean_val,
std_val,
skew_val,
kurtosis_val,
energy,
power,
envelope_mean,
envelope_std,
attack_time if np.isfinite(attack_time) else 0.0,
decay_time if np.isfinite(decay_time) else 0.0,
peak_density,
dynamic_range,
np.percentile(envelope, 25),
np.percentile(envelope, 75),
np.median(envelope)
]
return features
def _extract_frequency_features_numpy(self, audio_array: np.ndarray, sample_rate: int) -> list:
"""Extract frequency domain features using only numpy"""
# Short-time analysis
hop_length = 512
frame_length = 2048
features = []
# Process in overlapping windows
n_frames = (len(audio_array) - frame_length) // hop_length + 1
frame_features = []
for i in range(0, min(n_frames, 50)): # Limit to 50 frames for performance
start = i * hop_length
end = start + frame_length
if end > len(audio_array):
break
frame = audio_array[start:end]
# Apply window
window = np.hanning(len(frame))
windowed_frame = frame * window
# FFT
fft = np.fft.fft(windowed_frame)
magnitude = np.abs(fft)
# Frame-level features
frame_energy = np.sum(magnitude ** 2)
frame_centroid = np.sum(np.arange(len(magnitude)) * magnitude) / np.sum(magnitude)
frame_features.append([frame_energy, frame_centroid])
if frame_features:
frame_features = np.array(frame_features)
# Aggregate features across frames
features.extend([
np.mean(frame_features[:, 0]), # Mean energy
np.std(frame_features[:, 0]), # Energy std
np.mean(frame_features[:, 1]), # Mean centroid
np.std(frame_features[:, 1]), # Centroid std
np.max(frame_features[:, 0]), # Max energy
np.min(frame_features[:, 0]), # Min energy
np.median(frame_features[:, 0]), # Median energy
np.var(frame_features[:, 0]), # Energy variance
np.mean(np.diff(frame_features[:, 0])), # Energy delta
np.std(np.diff(frame_features[:, 0])) # Energy delta std
])
else:
features.extend(np.zeros(10))
# Add some basic harmonic features
fft_full = np.fft.fft(audio_array)
magnitude_full = np.abs(fft_full)
# Find fundamental frequency (simple peak detection)
freqs = np.fft.fftfreq(len(audio_array), 1/sample_rate)
pos_freqs = freqs[:len(freqs)//2]
pos_magnitude = magnitude_full[:len(magnitude_full)//2]
# Harmonic features
peak_idx = np.argmax(pos_magnitude)
fundamental_freq = pos_freqs[peak_idx]
# Harmonic ratios (simple approximation)
harmonic_features = []
for harmonic in [2, 3, 4, 5]:
target_freq = fundamental_freq * harmonic
if target_freq < sample_rate / 2:
target_idx = np.argmin(np.abs(pos_freqs - target_freq))
harmonic_ratio = pos_magnitude[target_idx] / pos_magnitude[peak_idx]
harmonic_features.append(harmonic_ratio)
else:
harmonic_features.append(0.0)
features.extend(harmonic_features)
# Add more frequency-based features
features.extend([
fundamental_freq / sample_rate, # Normalized fundamental
np.sum(pos_magnitude > np.mean(pos_magnitude)), # Number of significant peaks
np.sum(pos_magnitude) / len(pos_magnitude), # Average magnitude
np.std(pos_magnitude), # Magnitude std
np.max(pos_magnitude) / np.mean(pos_magnitude), # Peak prominence
np.sum(pos_magnitude[:len(pos_magnitude)//4]), # Low freq energy
np.sum(pos_magnitude[len(pos_magnitude)//4:len(pos_magnitude)//2]), # Mid freq energy
np.sum(pos_magnitude[len(pos_magnitude)//2:3*len(pos_magnitude)//4]), # High freq energy
np.sum(pos_magnitude[3*len(pos_magnitude)//4:]), # Very high freq energy
np.corrcoef(pos_magnitude[:-1], pos_magnitude[1:])[0,1] if len(pos_magnitude) > 1 else 0.0, # Spectral autocorr
            np.sum(np.diff(pos_magnitude) > 0) / len(pos_magnitude)  # Fraction of rising bins (crude spectral-flux proxy)
])
return features
def _estimate_tempo_numpy(self, audio_array: np.ndarray, sample_rate: int) -> float:
"""Simple tempo estimation using numpy only"""
try:
# Simple envelope-based tempo detection
hop_length = 512
frame_length = 2048
# Calculate envelope
envelope = np.abs(audio_array)
# Downsample envelope
n_frames = (len(envelope) - frame_length) // hop_length + 1
envelope_frames = []
for i in range(0, min(n_frames, 1000)): # Limit frames
start = i * hop_length
end = start + frame_length
if end > len(envelope):
break
frame_energy = np.mean(envelope[start:end])
envelope_frames.append(frame_energy)
if len(envelope_frames) < 10:
return 120.0
envelope_frames = np.array(envelope_frames)
# Find peaks in envelope
from scipy.signal import find_peaks
peaks, _ = find_peaks(envelope_frames, height=np.mean(envelope_frames))
if len(peaks) > 2:
# Calculate intervals between peaks
peak_intervals = np.diff(peaks) * hop_length / sample_rate
# Filter reasonable intervals (0.2 to 2 seconds)
valid_intervals = peak_intervals[(peak_intervals > 0.2) & (peak_intervals < 2.0)]
if len(valid_intervals) > 0:
avg_interval = np.mean(valid_intervals)
tempo = 60.0 / avg_interval
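                    # e.g. an average inter-peak interval of 0.5 s gives 60 / 0.5 = 120 BPM before clamping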
# Constrain to reasonable range
tempo = max(60, min(200, tempo))
return tempo
return 120.0
except Exception as e:
logger.warning(f"Numpy tempo estimation failed: {e}")
return 120.0
def extract_mert_features(self, audio_array: np.ndarray, sample_rate: int) -> np.ndarray:
"""Extract MERT features using the Music Understanding Model"""
try:
# Load MERT model on demand
self._load_mert_model()
# Resample to 24kHz for MERT processing if needed
if sample_rate != 24000:
logger.info(f"Resampling from {sample_rate}Hz to 24kHz for MERT...")
from scipy.signal import resample
target_length = int(len(audio_array) * 24000 / sample_rate)
audio_array_24k = resample(audio_array, target_length)
else:
audio_array_24k = audio_array
# Critical: Convert to float32 for MERT (community checkpoints expect float32)
audio_array_24k = audio_array_24k.astype(np.float32, copy=False)
# Use multi-window strategy like CLAP to capture variation
total_duration = len(audio_array_24k) / 24000
logger.info(f"Processing MERT features for {len(audio_array_24k)} samples at 24kHz ({total_duration:.1f} seconds)")
# Multi-window sampling approach (5-second windows as per MERT paper)
hop_seconds = 5 # Move window every 5 seconds
win_seconds = 5 # Each window is 5 seconds (MERT positional encodings trained on 5s)
hop_samples = hop_seconds * 24000
win_samples = win_seconds * 24000
samples = []
if total_duration <= win_seconds:
# Short audio: use the entire thing
logger.info("Short audio: using entire file for MERT")
samples = [audio_array_24k]
else:
# Multi-window sampling: overlapping windows across entire track
logger.info("Multi-window sampling for MERT: extracting overlapping windows")
max_offset = int(total_duration - win_seconds) + 1
for t in range(0, max_offset, hop_seconds):
start_sample = t * 24000
end_sample = start_sample + win_samples
# Ensure we don't go beyond the audio length
if end_sample <= len(audio_array_24k):
window = audio_array_24k[start_sample:end_sample]
samples.append(window)
else:
# Last window: take what's available
window = audio_array_24k[start_sample:]
                        if len(window) >= win_samples // 2:  # At least 2.5 seconds of real audio
# Pad to full window length
padded_window = np.pad(window, (0, win_samples - len(window)))
samples.append(padded_window)
break
logger.info(f"Generated {len(samples)} MERT windows covering entire track")
# Process each sample and collect embeddings
embeddings = []
for i, sample in enumerate(samples):
sample_length = win_samples
if len(sample) < sample_length:
# Pad short samples with zeros
sample = np.pad(sample, (0, sample_length - len(sample)))
logger.info(f"Processing MERT sample {i+1}/{len(samples)}")
# Process with MERT processor
inputs = self.mert_processor(
sample,
sampling_rate=24000,
return_tensors="pt"
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = self.mert_model(**inputs, output_hidden_states=True)
                    # hidden_states is a tuple of 13 tensors for MERT-v1-95M (embedding output
                    # + 12 transformer layers), each of shape [batch_size, time_steps, 768]
                    all_hidden_states = outputs.hidden_states
# Time reduction: average across time dimension
# Shape: [13, batch_size, 768]
time_reduced_states = torch.stack([layer.mean(dim=1) for layer in all_hidden_states])
# Take the mean across all layers to get a single 768-dim representation
# This follows the typical approach for utterance-level representations
window_embedding = time_reduced_states.mean(dim=0).squeeze() # Shape: [768]
embeddings.append(window_embedding.cpu().numpy().flatten())
# Average the raw embeddings from all samples (like CLAP)
final_embedding = np.mean(embeddings, axis=0)
# Ensure finite values
mert_features = np.nan_to_num(final_embedding, nan=0.0, posinf=0.0, neginf=0.0)
# L2 normalize
norm = np.linalg.norm(mert_features)
if norm > 0:
mert_features = mert_features / norm
else:
# If norm is 0, create a small non-zero vector
mert_features = np.ones(768, dtype=np.float32) * 0.001
mert_features = mert_features / np.linalg.norm(mert_features)
# Final check for finite values
mert_features = np.nan_to_num(mert_features, nan=0.0, posinf=0.0, neginf=0.0)
logger.info(f"Extracted MERT features: {len(mert_features)} dimensions (norm: {np.linalg.norm(mert_features):.6f})")
return mert_features
except Exception as e:
logger.error(f"Error extracting MERT features: {str(e)}")
# Return normalized zero vector if extraction fails
zero_features = np.ones(768, dtype=np.float32) * 0.001
return zero_features / np.linalg.norm(zero_features)
    def fuse_embeddings(self, clap_embedding: np.ndarray, pitch_features: np.ndarray = None, mert_features: np.ndarray = None):
        """Fuse CLAP, pitch-aware, and/or MERT features: returns a single fused vector for
        VECTOR_CONCAT, or a dict of per-model embeddings for SCORE_FUSION"""
if FUSION_MODE == "VECTOR_CONCAT":
# Normalize before and after concatenation (critical for proper similarity computation)
embeddings_to_fuse = []
# Step 1: Normalize each embedding individually (clap_u = normalize(clap))
# Always include CLAP
clap_u = clap_embedding / np.linalg.norm(clap_embedding)
embeddings_to_fuse.append(clap_u)
# Add MERT if available
if mert_features is not None:
mert_u = mert_features / np.linalg.norm(mert_features)
embeddings_to_fuse.append(mert_u)
# Add pitch features if available
if pitch_features is not None:
pitch_u = pitch_features / np.linalg.norm(pitch_features)
embeddings_to_fuse.append(pitch_u)
# Step 2: Concatenate normalized embeddings (fused = cat([clap_u, mert_u]))
fused = np.concatenate(embeddings_to_fuse)
# Step 3: Normalize the concatenated result (fused = normalize(fused))
# Critical: Without this, similarity is inflated by √2 (0.83 → 0.93)
fused_norm = np.linalg.norm(fused)
if fused_norm > 0:
fused = fused / fused_norm
else:
logger.warning("Zero norm in fused embedding - creating fallback vector")
fused = np.ones_like(fused) * 0.001
fused = fused / np.linalg.norm(fused)
# Verify norm is 1.0 (assert abs(fused.norm() - 1.0) < 1e-6)
final_norm = np.linalg.norm(fused)
if abs(final_norm - 1.0) > 1e-6:
logger.warning(f"Fused embedding norm is {final_norm:.8f}, not 1.0 - normalization issue!")
# Ensure finite values for JSON serialization
fused = np.nan_to_num(fused, nan=0.0, posinf=0.0, neginf=0.0)
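            # With unit-normalized components and this final renormalization, the cosine between
            # two fused vectors equals the unweighted mean of the per-component cosines,
            # e.g. cos_clap = 0.90 and cos_mert = 0.70 give cos_fused = (0.90 + 0.70) / 2 = 0.80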
return fused
else:
# SCORE_FUSION: return embeddings separately for weighted similarity calculation
result = {}
result["clap"] = np.nan_to_num(clap_embedding, nan=0.0, posinf=0.0, neginf=0.0)
if mert_features is not None:
result["mert"] = np.nan_to_num(mert_features, nan=0.0, posinf=0.0, neginf=0.0)
if pitch_features is not None:
result["pitch"] = np.nan_to_num(pitch_features, nan=0.0, posinf=0.0, neginf=0.0)
return result
def is_supported_format(self, image_url: str) -> bool:
"""Check if image format is supported by PIL/CLIP"""
unsupported_extensions = ['.avif', '.heic', '.heif']
url_lower = image_url.lower()
return not any(url_lower.endswith(ext) for ext in unsupported_extensions)
def detect_image_format(self, content: bytes) -> str:
"""Detect actual image format from content"""
try:
# Check for AVIF signature
if content.startswith(b'\x00\x00\x00') and b'ftypavif' in content[:32]:
return 'AVIF'
# Check for HEIC signature
elif content.startswith(b'\x00\x00\x00') and b'ftyp' in content[:32] and (b'heic' in content[:32] or b'heix' in content[:32]):
return 'HEIC'
# Check for WebP
elif content.startswith(b'RIFF') and b'WEBP' in content[:12]:
return 'WebP'
# Check for PNG
elif content.startswith(b'\x89PNG\r\n\x1a\n'):
return 'PNG'
# Check for JPEG
elif content.startswith(b'\xff\xd8\xff'):
return 'JPEG'
# Check for GIF
elif content.startswith((b'GIF87a', b'GIF89a')):
return 'GIF'
else:
return 'Unknown'
except:
return 'Unknown'
def encode_image(self, image_url: str) -> list:
try:
logger.info(f"Processing image: {image_url}")
# Quick URL-based format check first
if not self.is_supported_format(image_url):
logger.warning(f"Unsupported format detected from URL: {image_url}")
raise HTTPException(status_code=422, detail="Unsupported image format (AVIF/HEIC not supported)")
response = requests.get(image_url, timeout=30, headers={'User-Agent': 'CLIP-Service/1.0'})
response.raise_for_status()
# Detect actual format from content
image_format = self.detect_image_format(response.content)
logger.info(f"Detected image format: {image_format}")
if image_format in ['AVIF', 'HEIC']:
logger.warning(f"Unsupported format detected: {image_format} for {image_url}")
raise HTTPException(status_code=422, detail=f"Unsupported image format: {image_format}")
try:
image = Image.open(io.BytesIO(response.content))
except Exception as e:
logger.error(f"PIL cannot open image {image_url}: {str(e)}")
if "cannot identify image file" in str(e).lower():
raise HTTPException(status_code=422, detail="Unsupported or corrupted image format")
raise
if image.mode != 'RGB':
logger.info(f"Converting image from {image.mode} to RGB")
image = image.convert('RGB')
# Resize image if too large to avoid memory issues
max_size = 224 # CLIP's expected input size
if max(image.size) > max_size:
image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
# Use the working method directly (Method 3) to avoid fallback overhead
inputs = self.clip_processor(
images=[image],
return_tensors="pt"
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
image_features = self.clip_model.get_image_features(**inputs)
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
# Ensure safe values for JSON serialization
embedding = image_features.cpu().numpy().flatten()
embedding = np.nan_to_num(embedding, nan=0.0, posinf=0.0, neginf=0.0)
return embedding.tolist()
except Exception as e:
logger.error(f"Error encoding image {image_url}: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to encode image: {str(e)}")
def encode_text(self, text: str) -> list:
try:
logger.info(f"Processing text: {text[:50]}...")
inputs = self.clip_processor(text=[text], return_tensors="pt", padding=True).to(self.device)
with torch.no_grad():
text_features = self.clip_model.get_text_features(**inputs)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
# Ensure safe values for JSON serialization
embedding = text_features.cpu().numpy().flatten()
embedding = np.nan_to_num(embedding, nan=0.0, posinf=0.0, neginf=0.0)
return embedding.tolist()
except Exception as e:
logger.error(f"Error encoding text '{text[:50]}...': {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to encode text: {str(e)}")
def encode_audio(self, audio_url: str) -> list:
try:
logger.info(f"Processing audio: {audio_url}")
# Load CLAP model on demand
self._load_clap_model()
            # Log up front whether pitch fusion will run for this request
if ENABLE_PITCH_FUSION and PITCH_AWARE_AVAILABLE:
logger.info("Pitch fusion enabled with librosa features")
# Download audio file
response = requests.get(audio_url, timeout=60, headers={'User-Agent': 'CLAP-Service/1.0'})
response.raise_for_status()
# Save to temporary file
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file:
tmp_file.write(response.content)
tmp_path = tmp_file.name
try:
# Load audio using soundfile first, then resample with librosa if needed
# This avoids the caching issues with librosa.load
logger.info("Loading audio with soundfile...")
audio_array, original_sr = sf.read(tmp_path)
# Convert to mono if needed
if len(audio_array.shape) > 1:
logger.info("Converting stereo to mono")
audio_array = audio_array.mean(axis=1)
else:
logger.info("Audio is already mono")
# Resample to 48kHz for CLAP processing
if original_sr != 48000:
logger.info(f"Resampling from {original_sr}Hz to 48kHz...")
# Use scipy for resampling to avoid librosa caching issues
from scipy.signal import resample
target_length = int(len(audio_array) * 48000 / original_sr)
audio_array_48k = resample(audio_array, target_length)
else:
audio_array_48k = audio_array
# Critical: Convert to float32 for CLAP (community checkpoints expect float32)
audio_array_48k = audio_array_48k.astype(np.float32, copy=False)
total_duration = len(audio_array_48k) / 48000
logger.info(f"Audio loaded: {len(audio_array_48k)} samples at 48kHz ({total_duration:.1f} seconds)")
# Check audio duration limit and truncate both arrays consistently
if total_duration > MAX_AUDIO_DURATION_SEC:
logger.warning(f"Audio duration {total_duration:.1f}s exceeds limit {MAX_AUDIO_DURATION_SEC}s, truncating...")
max_samples_48k = MAX_AUDIO_DURATION_SEC * 48000
max_samples_orig = MAX_AUDIO_DURATION_SEC * original_sr
# Truncate both arrays to keep MERT and Pitch processing within limits
audio_array_48k = audio_array_48k[:max_samples_48k]
audio_array = audio_array[:max_samples_orig]
total_duration = MAX_AUDIO_DURATION_SEC
logger.info(f"Truncated both arrays to {total_duration:.1f} seconds (48kHz: {len(audio_array_48k)} samples, {original_sr}Hz: {len(audio_array)} samples)")
# Process with CLAP (10s windows, 5s hops)
clap_embedding = self._process_clap_embeddings(audio_array_48k, total_duration)
# Initialize additional features
pitch_features = None
mert_features = None
# Process with MERT features if enabled
if ENABLE_MERT_FUSION and MERT_AVAILABLE:
try:
logger.info("Processing with MERT features for fusion...")
mert_features = self.extract_mert_features(audio_array, original_sr)
except Exception as e:
logger.error(f"MERT feature processing failed: {str(e)}")
mert_features = None
# Process with pitch-aware features if enabled (can run alongside MERT)
if ENABLE_PITCH_FUSION and PITCH_AWARE_AVAILABLE:
try:
logger.info("Processing with pitch-aware features for fusion...")
pitch_features = self._process_pitch_features(audio_array, original_sr, total_duration)
except Exception as e:
logger.error(f"Pitch feature processing failed: {str(e)}")
pitch_features = None
# Handle fusion based on what features are available
if mert_features is not None or pitch_features is not None:
if FUSION_MODE == "VECTOR_CONCAT":
final_embedding = self.fuse_embeddings(clap_embedding, pitch_features, mert_features)
                        if mert_features is not None and pitch_features is not None:
                            logger.info(f"Fused embedding dimensions: {len(final_embedding)} (CLAP 512 + MERT 768 + Pitch 85)")
                        elif mert_features is not None:
                            logger.info(f"Fused embedding dimensions: {len(final_embedding)} (CLAP 512 + MERT 768)")
                        elif pitch_features is not None:
                            logger.info(f"Fused embedding dimensions: {len(final_embedding)} (CLAP 512 + Pitch 85)")
return final_embedding.tolist()
else:
# SCORE_FUSION: return embeddings separately
logger.info("Score fusion mode: returning separate embeddings")
result = self.fuse_embeddings(clap_embedding, pitch_features, mert_features)
result["fusion_mode"] = "SCORE_FUSION"
result["fusion_alpha"] = FUSION_ALPHA
# Convert to lists for JSON serialization
for key in result:
if isinstance(result[key], np.ndarray):
result[key] = result[key].tolist()
return result
else:
# CLAP-only processing (backwards compatible)
logger.info("Using CLAP-only processing")
return clap_embedding.tolist()
finally:
# Clean up temp file
if os.path.exists(tmp_path):
os.unlink(tmp_path)
except Exception as e:
logger.error(f"Error encoding audio {audio_url}: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to encode audio: {str(e)}")
def _process_clap_embeddings(self, audio_array: np.ndarray, total_duration: float) -> np.ndarray:
"""Process audio with CLAP using 10s windows and 5s hops"""
# Multi-window sampling approach for better discrimination
hop_seconds = 5 # Move window every 5 seconds
win_seconds = 10 # Each window is 10 seconds
hop_samples = hop_seconds * 48000
win_samples = win_seconds * 48000
samples = []
if total_duration <= win_seconds:
# Short audio: use the entire thing
logger.info("Short audio: using entire file for CLAP")
samples = [audio_array]
else:
# Multi-window sampling: overlapping windows across entire track
logger.info("Multi-window sampling for CLAP: extracting overlapping windows")
max_offset = int(total_duration - win_seconds) + 1
for t in range(0, max_offset, hop_seconds):
start_sample = t * 48000
end_sample = start_sample + win_samples
# Ensure we don't go beyond the audio length
if end_sample <= len(audio_array):
window = audio_array[start_sample:end_sample]
samples.append(window)
else:
# Last window: take what's available
window = audio_array[start_sample:]
if len(window) >= win_samples // 2: # At least 5 seconds
# Pad to full window length
padded_window = np.pad(window, (0, win_samples - len(window)))
samples.append(padded_window)
break
logger.info(f"Generated {len(samples)} CLAP windows covering entire track")
# Process each sample and collect embeddings
embeddings = []
for i, sample in enumerate(samples):
sample_length = win_samples
if len(sample) < sample_length:
# Pad short samples with zeros
sample = np.pad(sample, (0, sample_length - len(sample)))
logger.info(f"Processing CLAP sample {i+1}/{len(samples)}")
# Process with CLAP
inputs = self.clap_processor(
audios=sample,
sampling_rate=48000,
return_tensors="pt"
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
audio_features = self.clap_model.get_audio_features(**inputs)
window_vec = audio_features.squeeze(0) # 512-D, no L2
embeddings.append(window_vec.cpu().numpy().flatten())
# Average the raw embeddings from all samples
final_embedding = np.mean(embeddings, axis=0)
# Ensure finite values before normalization
final_embedding = np.nan_to_num(final_embedding, nan=0.0, posinf=0.0, neginf=0.0)
# Ensure proper L2 normalization for cosine similarity
norm = np.linalg.norm(final_embedding)
if norm > 0:
final_embedding = final_embedding / norm
else:
logger.warning("Zero norm CLAP embedding detected")
# Create a small normalized vector instead of zero
final_embedding = np.ones_like(final_embedding) * 0.001
final_embedding = final_embedding / np.linalg.norm(final_embedding)
# Final safety check for JSON serialization
final_embedding = np.nan_to_num(final_embedding, nan=0.0, posinf=0.0, neginf=0.0)
# Verify normalization
final_norm = np.linalg.norm(final_embedding)
logger.info(f"Final CLAP embedding norm: {final_norm:.6f} (should be ~1.0)")
return final_embedding
def _process_pitch_features(self, audio_array: np.ndarray, original_sr: int, total_duration: float) -> np.ndarray:
"""Process audio with pitch-aware features using 5s windows and 2s hops"""
# Pitch feature processing parameters
hop_seconds = 2 # Move window every 2 seconds
win_seconds = 5 # Each window is 5 seconds
hop_samples = hop_seconds * original_sr
win_samples = win_seconds * original_sr
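        # Windows are cut from the original-sample-rate audio (not the 48 kHz copy used for
        # CLAP), so hop_samples/win_samples scale with original_sr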
samples = []
if total_duration <= win_seconds:
# Short audio: use the entire thing
logger.info("Short audio: using entire file for pitch features")
samples = [audio_array]
else:
# Multi-window sampling: overlapping windows across entire track
logger.info("Multi-window sampling for pitch features: extracting overlapping windows")
max_offset = int(total_duration - win_seconds) + 1
for t in range(0, max_offset, hop_seconds):
start_sample = t * original_sr
end_sample = start_sample + win_samples
# Ensure we don't go beyond the audio length
if end_sample <= len(audio_array):
window = audio_array[start_sample:end_sample]
samples.append(window)
else:
# Last window: take what's available
window = audio_array[start_sample:]
if len(window) >= win_samples // 2: # At least 2.5 seconds
# Pad to full window length
padded_window = np.pad(window, (0, win_samples - len(window)))
samples.append(padded_window)
break
logger.info(f"Generated {len(samples)} pitch feature windows covering entire track")
# Process each sample and collect features
feature_vectors = []
for i, sample in enumerate(samples):
sample_length = win_samples
if len(sample) < sample_length:
# Pad short samples with zeros
sample = np.pad(sample, (0, sample_length - len(sample)))
logger.info(f"Processing pitch features sample {i+1}/{len(samples)}")
# Extract pitch features
pitch_features = self.extract_pitch_features(sample, original_sr)
feature_vectors.append(pitch_features)
# Average the raw feature vectors from all samples
final_features = np.mean(feature_vectors, axis=0)
# Ensure finite values before normalization
final_features = np.nan_to_num(final_features, nan=0.0, posinf=0.0, neginf=0.0)
# Ensure proper L2 normalization for cosine similarity
norm = np.linalg.norm(final_features)
if norm > 0:
final_features = final_features / norm
else:
logger.warning("Zero norm pitch features detected")
# Create a small normalized vector instead of zero
final_features = np.ones_like(final_features) * 0.001
final_features = final_features / np.linalg.norm(final_features)
# Final safety check for JSON serialization
final_features = np.nan_to_num(final_features, nan=0.0, posinf=0.0, neginf=0.0)
# Verify normalization
final_norm = np.linalg.norm(final_features)
logger.info(f"Final pitch features norm: {final_norm:.6f} (should be ~1.0)")
return final_features
# Initialize service with error handling
logger.info("Initializing CLIP service...")
try:
clip_service = CLIPService()
logger.info("CLIP service initialized successfully!")
except Exception as e:
logger.error(f"Failed to initialize CLIP service: {str(e)}")
logger.error(f"Error details: {type(e).__name__}: {str(e)}")
# For now, we'll let the app start but service calls will fail gracefully
clip_service = None
class ImageRequest(BaseModel):
image_url: str
class TextRequest(BaseModel):
text: str
class AudioRequest(BaseModel):
audio_url: str
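# Example requests (assuming the Space is served locally on the default port 7860; the URLs
# and text below are illustrative only):
#   curl -X POST http://localhost:7860/encode/text \
#        -H "Content-Type: application/json" -d '{"text": "dreamy ambient piano"}'
#   curl -X POST http://localhost:7860/encode/image \
#        -H "Content-Type: application/json" -d '{"image_url": "https://example.com/cover.jpg"}'
# Both return {"embedding": [...], "dimensions": <n>} (768 for clip-vit-large-patch14's
# shared text/image space).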
@app.get("/")
async def root():
return {
"message": "CLIP Service API",
"version": "1.0.0",
"model": "clip-vit-large-patch14",
"endpoints": ["/encode/image", "/encode/text", "/encode/audio", "/health"],
"status": "ready" if clip_service else "error"
}
@app.post("/encode/image")
async def encode_image(request: ImageRequest):
if not clip_service:
raise HTTPException(status_code=503, detail="CLIP service not available")
embedding = clip_service.encode_image(request.image_url)
safe_embedding = sanitize_for_json(embedding)
return {"embedding": safe_embedding, "dimensions": len(safe_embedding)}
@app.post("/encode/text")
async def encode_text(request: TextRequest):
if not clip_service:
raise HTTPException(status_code=503, detail="CLIP service not available")
embedding = clip_service.encode_text(request.text)
safe_embedding = sanitize_for_json(embedding)
return {"embedding": safe_embedding, "dimensions": len(safe_embedding)}
@app.post("/encode/audio")
async def encode_audio(request: AudioRequest):
if not clip_service:
raise HTTPException(status_code=503, detail="CLAP service not available")
if not CLAP_AVAILABLE:
raise HTTPException(status_code=501, detail="CLAP model not available in this transformers version")
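    # Depending on fusion settings this returns either a single fused vector (512 dims for
    # CLAP-only, 597 with pitch, 1280 with MERT, 1365 with both) or, under SCORE_FUSION,
    # a dict of per-model embeddings plus fusion metadata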
embedding = clip_service.encode_audio(request.audio_url)
# Handle both single embedding and fusion mode results
if isinstance(embedding, dict):
# Score fusion mode - sanitize all embeddings
safe_embedding = {}
dimensions = {}
for key, value in embedding.items():
if key in ["clap", "pitch", "mert"] and isinstance(value, list):
safe_embedding[key] = sanitize_for_json(value)
dimensions[key] = len(safe_embedding[key])
else:
safe_embedding[key] = value
return {"embedding": safe_embedding, "dimensions": dimensions}
else:
# Single embedding (CLAP-only or concatenated)
safe_embedding = sanitize_for_json(embedding)
return {"embedding": safe_embedding, "dimensions": len(safe_embedding)}
@app.get("/health")
async def health_check():
if not clip_service:
return {
"status": "unhealthy",
"model": "clip-vit-large-patch14",
"error": "Service failed to initialize"
}
health_info = {
"status": "healthy",
"models": {
"clip": "clip-vit-large-patch14",
"clap": f"clap-htsat-unfused (lazy loaded, method: {CLAP_METHOD})" if CLAP_AVAILABLE else "not available",
"mert": f"MERT-v1-95M (lazy loaded, method: {MERT_METHOD})" if MERT_AVAILABLE else "not available"
},
"device": clip_service.device,
"service": "ready",
"cache_dir": cache_dir
}
# Add pitch-aware information
if PITCH_AWARE_AVAILABLE:
health_info["models"]["pitch_aware"] = f"librosa features ({PITCH_METHOD})"
# Add fusion information
fusion_enabled = ENABLE_PITCH_FUSION or ENABLE_MERT_FUSION
if fusion_enabled:
health_info["fusion"] = {
"enabled": True,
"mode": FUSION_MODE,
"pitch_fusion_enabled": ENABLE_PITCH_FUSION,
"mert_fusion_enabled": ENABLE_MERT_FUSION,
"pitch_aware_available": PITCH_AWARE_AVAILABLE,
"mert_available": MERT_AVAILABLE
}
if FUSION_MODE == "SCORE_FUSION":
health_info["fusion"]["alpha"] = FUSION_ALPHA
else:
health_info["fusion"] = {
"enabled": False,
"mode": "CLAP_ONLY"
}
return health_info
if __name__ == "__main__":
import uvicorn
port = int(os.environ.get("PORT", 7860)) # Hugging Face uses port 7860
uvicorn.run(app, host="0.0.0.0", port=port)