import os
import tempfile
from pathlib import Path
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import CLIPProcessor, CLIPModel

try:
    from transformers import ClapModel, ClapProcessor
    CLAP_AVAILABLE = True
    CLAP_METHOD = "transformers"
except ImportError as e1:
    CLAP_AVAILABLE = False
    CLAP_METHOD = None

# Check MERT availability
try:
    from transformers import AutoModel, Wav2Vec2FeatureExtractor
    MERT_AVAILABLE = True
    MERT_METHOD = "transformers"
except ImportError as e2:
    MERT_AVAILABLE = False
    MERT_METHOD = None

import torch
import torchaudio
from PIL import Image
import requests
import numpy as np
import io
import logging
import librosa
import soundfile as sf
import scipy.signal

# Set environment to disable librosa caching
os.environ['LIBROSA_CACHE_DIR'] = '/tmp'
os.environ['JOBLIB_TEMP_FOLDER'] = '/tmp'
# Disable librosa caching completely to avoid the __o_fold error
os.environ['LIBROSA_CACHE_LEVEL'] = '0'
os.environ['LIBROSA_CACHE_COMPRESS'] = '0'

# Check pitch-aware model availability
try:
    # Try to use a simpler pitch-aware approach with librosa
    import librosa
    PITCH_AWARE_AVAILABLE = True
    PITCH_METHOD = "librosa_chroma"
except ImportError:
    PITCH_AWARE_AVAILABLE = False
    PITCH_METHOD = None
# Fusion configuration
FUSION_MODE = os.environ.get('FUSION_MODE', 'VECTOR_CONCAT')  # VECTOR_CONCAT or SCORE_FUSION
FUSION_ALPHA = float(os.environ.get('FUSION_ALPHA', '0.6'))  # Alpha for score fusion
ENABLE_PITCH_FUSION = os.environ.get('ENABLE_PITCH_FUSION', 'false').lower() == 'true' and PITCH_AWARE_AVAILABLE
ENABLE_MERT_FUSION = os.environ.get('ENABLE_MERT_FUSION', 'false').lower() == 'true' and MERT_AVAILABLE

# Audio processing limits
MAX_AUDIO_DURATION_SEC = int(os.environ.get('MAX_AUDIO_DURATION_SEC', '600'))  # 10 minutes default
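
# Example environment configuration (illustrative only; the values shown here are
# assumptions, not defaults of any particular deployment):
#
#   export ENABLE_MERT_FUSION=true
#   export ENABLE_PITCH_FUSION=false
#   export FUSION_MODE=VECTOR_CONCAT      # or SCORE_FUSION
#   export FUSION_ALPHA=0.6               # only used downstream in SCORE_FUSION mode
#   export MAX_AUDIO_DURATION_SEC=300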

# Set up cache directories
cache_dir = os.environ.get('TRANSFORMERS_CACHE', '/code/cache')
os.makedirs(cache_dir, exist_ok=True)
os.environ['TRANSFORMERS_CACHE'] = cache_dir
os.environ['HF_HOME'] = cache_dir
os.environ['TORCH_HOME'] = cache_dir

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="CLIP Service", version="1.0.0")

def sanitize_for_json(embedding):
    """Sanitize an embedding for JSON serialization."""
    if isinstance(embedding, np.ndarray):
        # Ensure finite values and convert to list
        embedding = np.nan_to_num(embedding, nan=0.0, posinf=0.0, neginf=0.0)
        return embedding.tolist()
    elif isinstance(embedding, list):
        # Check each element for finite values
        return [float(x) if np.isfinite(x) else 0.0 for x in embedding]
    else:
        return embedding


# Log CLAP, MERT, and pitch-aware availability after the logger is initialized
logger.info(f"CLAP availability: {CLAP_AVAILABLE}, method: {CLAP_METHOD}")
logger.info(f"MERT availability: {MERT_AVAILABLE}, method: {MERT_METHOD}")
logger.info(f"Pitch-aware availability: {PITCH_AWARE_AVAILABLE}, method: {PITCH_METHOD}")
logger.info(f"Pitch fusion enabled: {ENABLE_PITCH_FUSION}")
logger.info(f"MERT fusion enabled: {ENABLE_MERT_FUSION}")
logger.info(f"Fusion mode: {FUSION_MODE}")
if FUSION_MODE == 'SCORE_FUSION':
    logger.info(f"Fusion alpha: {FUSION_ALPHA}")

class CLIPService:
    def __init__(self):
        logger.info("Loading CLIP model...")
        self.clap_model = None
        self.clap_processor = None
        self.mert_model = None
        self.mert_processor = None
        # Simple in-memory cache for pitch features keyed by audio hash.
        # Using a dict avoids the "unhashable type: numpy.ndarray" error we hit with functools.lru_cache.
        self.pitch_feature_cache: dict[int, np.ndarray] = {}
        try:
            # Use CPU for Hugging Face free tier
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Using device: {self.device}")
            # Load CLIP model with explicit cache directory
            logger.info("Loading CLIP model from HuggingFace...")
            self.clip_model = CLIPModel.from_pretrained(
                "openai/clip-vit-large-patch14",
                cache_dir=cache_dir,
                local_files_only=False
            ).to(self.device)
            logger.info("Loading CLIP processor...")
            use_fast = os.environ.get('USE_FAST_PROCESSOR', 'true').lower() == 'true'
            logger.info(f"Using fast processor: {use_fast}")
            self.clip_processor = CLIPProcessor.from_pretrained(
                "openai/clip-vit-large-patch14",
                cache_dir=cache_dir,
                local_files_only=False,
                use_fast=use_fast
            )
            logger.info(f"CLIP model loaded successfully on {self.device}")
        except Exception as e:
            logger.error(f"Failed to load CLIP model: {str(e)}")
            logger.error(f"Error type: {type(e).__name__}")
            raise RuntimeError(f"CLIP model loading failed: {str(e)}")
    def _load_clap_model(self):
        """Load the CLAP model on demand."""
        if not CLAP_AVAILABLE:
            raise RuntimeError("CLAP model not available - transformers version may not support CLAP")
        if self.clap_model is None:
            logger.info(f"Loading CLAP model on demand using {CLAP_METHOD} method...")
            try:
                if CLAP_METHOD == "transformers":
                    logger.info("Loading CLAP model from HuggingFace...")
                    self.clap_model = ClapModel.from_pretrained(
                        "laion/clap-htsat-unfused",
                        cache_dir=cache_dir,
                        local_files_only=False
                    ).to(self.device)
                    logger.info("Loading CLAP processor...")
                    use_fast = os.environ.get('USE_FAST_PROCESSOR', 'true').lower() == 'true'
                    self.clap_processor = ClapProcessor.from_pretrained(
                        "laion/clap-htsat-unfused",
                        cache_dir=cache_dir,
                        local_files_only=False,
                        use_fast=use_fast
                    )
                logger.info(f"CLAP model loaded successfully on {self.device} using {CLAP_METHOD}")
            except Exception as e:
                logger.error(f"Failed to load CLAP model: {str(e)}")
                logger.error(f"Error type: {type(e).__name__}")
                raise RuntimeError(f"CLAP model loading failed: {str(e)}")
    def _load_mert_model(self):
        """Load the MERT model on demand."""
        if not MERT_AVAILABLE:
            raise RuntimeError("MERT model not available - transformers version may not support MERT")
        if self.mert_model is None:
            logger.info(f"Loading MERT model on demand using {MERT_METHOD} method...")
            try:
                logger.info("Loading MERT model from HuggingFace...")
                self.mert_model = AutoModel.from_pretrained(
                    "m-a-p/MERT-v1-95M",
                    trust_remote_code=True,
                    cache_dir=cache_dir,
                    local_files_only=False
                ).to(self.device)
                # Guard against missing encoder (stale HF cache issue)
                if not hasattr(self.mert_model, "encoder"):
                    raise RuntimeError("MERT weights not loaded - clear HF cache and retry")
                logger.info("Loading MERT processor...")
                use_fast = os.environ.get('USE_FAST_PROCESSOR', 'true').lower() == 'true'
                self.mert_processor = Wav2Vec2FeatureExtractor.from_pretrained(
                    "m-a-p/MERT-v1-95M",
                    cache_dir=cache_dir,
                    local_files_only=False,
                    use_fast=use_fast
                )
                logger.info(f"MERT model loaded successfully on {self.device} using {MERT_METHOD}")
            except Exception as e:
                logger.error(f"Failed to load MERT model: {str(e)}")
                logger.error(f"Error type: {type(e).__name__}")
                raise RuntimeError(f"MERT model loading failed: {str(e)}")
    def extract_pitch_features(self, audio_array: np.ndarray, sample_rate: int) -> np.ndarray:
        """Extract pitch-aware features using numpy/scipy (avoiding all librosa caching issues) with a lightweight dict cache."""
        cache_key = hash(audio_array.tobytes())
        # Fast path: return the cached result if we've already seen this audio chunk
        if cache_key in self.pitch_feature_cache:
            return self.pitch_feature_cache[cache_key]
        # Slow path: compute fresh features and store them
        features = self._extract_pitch_features_impl(audio_array, sample_rate)
        self.pitch_feature_cache[cache_key] = features
        return features
    def _extract_pitch_features_impl(self, audio_array: np.ndarray, sample_rate: int) -> np.ndarray:
        """Implementation of pitch feature extraction (results are cached by the wrapper above)."""
        try:
            # Use pure numpy/scipy implementations to avoid all librosa caching issues
            features = []
            # Extract basic audio features using numpy/scipy only
            try:
                # Basic spectral features using numpy FFT
                spectral_features = self._extract_spectral_features_numpy(audio_array, sample_rate)
                features.extend(spectral_features)
                logger.info("✓ Spectral features extracted (numpy)")
            except Exception as e:
                logger.warning(f"Spectral feature extraction failed: {e}, using zeros")
                features.extend(np.zeros(20))  # 20 spectral features
            try:
                # Basic temporal features
                temporal_features = self._extract_temporal_features_numpy(audio_array, sample_rate)
                features.extend(temporal_features)
                logger.info("✓ Temporal features extracted (numpy)")
            except Exception as e:
                logger.warning(f"Temporal feature extraction failed: {e}, using zeros")
                features.extend(np.zeros(15))  # 15 temporal features
            try:
                # Basic frequency domain features
                frequency_features = self._extract_frequency_features_numpy(audio_array, sample_rate)
                features.extend(frequency_features)
                logger.info("✓ Frequency features extracted (numpy)")
            except Exception as e:
                logger.warning(f"Frequency feature extraction failed: {e}, using zeros")
                features.extend(np.zeros(25))  # 25 frequency features
            # Simple tempo estimation (fallback)
            try:
                tempo = self._estimate_tempo_numpy(audio_array, sample_rate)
                features.append(tempo)
                logger.info(f"✓ Tempo estimated: {tempo:.1f} BPM")
            except Exception as e:
                logger.warning(f"Tempo estimation failed: {e}, using default")
                features.append(120.0)  # Default BPM
            # Convert to a numpy array and check for NaN/inf values
            pitch_features = np.array(features, dtype=np.float32)
            # Replace any NaN or inf values with 0
            pitch_features = np.nan_to_num(pitch_features, nan=0.0, posinf=0.0, neginf=0.0)
            # Ensure we have the expected 85 dimensions
            if len(pitch_features) < 85:
                # Pad with zeros if needed
                padding = np.zeros(85 - len(pitch_features), dtype=np.float32)
                pitch_features = np.concatenate([pitch_features, padding])
            elif len(pitch_features) > 85:
                # Truncate if too long
                pitch_features = pitch_features[:85]
            # L2 normalize
            norm = np.linalg.norm(pitch_features)
            if norm > 0:
                pitch_features = pitch_features / norm
            else:
                # If the norm is 0, create a small non-zero vector
                pitch_features = np.ones(85, dtype=np.float32) * 0.001
                pitch_features = pitch_features / np.linalg.norm(pitch_features)
            # Final check for finite values
            pitch_features = np.nan_to_num(pitch_features, nan=0.0, posinf=0.0, neginf=0.0)
            # The result is cached by the dict-based cache in extract_pitch_features
            logger.info(f"Extracted pitch features: {len(pitch_features)} dimensions")
            return pitch_features
        except Exception as e:
            logger.error(f"Error extracting pitch features: {str(e)}")
            # Return a normalized near-zero vector if extraction fails
            zero_features = np.ones(85, dtype=np.float32) * 0.001
            return zero_features / np.linalg.norm(zero_features)
    def _extract_spectral_features_numpy(self, audio_array: np.ndarray, sample_rate: int) -> list:
        """Extract spectral features using only numpy (no librosa)."""
        # Compute FFT
        fft = np.fft.fft(audio_array)
        magnitude = np.abs(fft)
        freqs = np.fft.fftfreq(len(audio_array), 1 / sample_rate)
        # Only use positive frequencies
        pos_freqs = freqs[:len(freqs) // 2]
        pos_magnitude = magnitude[:len(magnitude) // 2]
        # Spectral centroid
        spectral_centroid = np.sum(pos_freqs * pos_magnitude) / np.sum(pos_magnitude)
        # Spectral rolloff (95% of energy)
        cumsum = np.cumsum(pos_magnitude)
        rolloff_idx = np.where(cumsum >= 0.95 * cumsum[-1])[0]
        spectral_rolloff = pos_freqs[rolloff_idx[0]] if len(rolloff_idx) > 0 else 0
        # Spectral spread
        spectral_spread = np.sqrt(np.sum(((pos_freqs - spectral_centroid) ** 2) * pos_magnitude) / np.sum(pos_magnitude))
        # Zero crossing rate
        zero_crossings = np.where(np.diff(np.sign(audio_array)))[0]
        zcr = len(zero_crossings) / len(audio_array)
        # RMS energy
        rms = np.sqrt(np.mean(audio_array ** 2))
        # Basic spectral features
        features = [
            spectral_centroid / sample_rate,  # Normalize
            spectral_rolloff / sample_rate,   # Normalize
            spectral_spread / sample_rate,    # Normalize
            zcr,
            rms,
            np.max(pos_magnitude),
            np.mean(pos_magnitude),
            np.std(pos_magnitude),
            np.sum(pos_magnitude),
            np.var(pos_magnitude)
        ]
        # Add frequency band energies (10 bands)
        band_energies = []
        n_bands = 10
        for i in range(n_bands):
            start_idx = i * len(pos_magnitude) // n_bands
            end_idx = (i + 1) * len(pos_magnitude) // n_bands
            band_energy = np.sum(pos_magnitude[start_idx:end_idx])
            band_energies.append(band_energy)
        features.extend(band_energies)
        return features
    def _extract_temporal_features_numpy(self, audio_array: np.ndarray, sample_rate: int) -> list:
        """Extract temporal features using only numpy."""
        # Basic statistics
        mean_val = np.mean(audio_array)
        std_val = np.std(audio_array)
        skew_val = np.mean(((audio_array - mean_val) / std_val) ** 3)
        kurtosis_val = np.mean(((audio_array - mean_val) / std_val) ** 4)
        # Energy-based features
        energy = np.sum(audio_array ** 2)
        power = energy / len(audio_array)
        # Envelope features
        envelope = np.abs(audio_array)
        envelope_mean = np.mean(envelope)
        envelope_std = np.std(envelope)
        # Attack/decay characteristics
        envelope_diff = np.diff(envelope)
        attack_time = np.mean(envelope_diff[envelope_diff > 0])
        decay_time = np.mean(-envelope_diff[envelope_diff < 0])
        # Peak characteristics
        peaks = np.where(np.diff(np.sign(np.diff(audio_array))) < 0)[0]
        peak_density = len(peaks) / len(audio_array)
        # Dynamics
        dynamic_range = np.max(envelope) - np.min(envelope)
        features = [
            mean_val,
            std_val,
            skew_val,
            kurtosis_val,
            energy,
            power,
            envelope_mean,
            envelope_std,
            attack_time if np.isfinite(attack_time) else 0.0,
            decay_time if np.isfinite(decay_time) else 0.0,
            peak_density,
            dynamic_range,
            np.percentile(envelope, 25),
            np.percentile(envelope, 75),
            np.median(envelope)
        ]
        return features
    def _extract_frequency_features_numpy(self, audio_array: np.ndarray, sample_rate: int) -> list:
        """Extract frequency-domain features using only numpy."""
        # Short-time analysis
        hop_length = 512
        frame_length = 2048
        features = []
        # Process in overlapping windows
        n_frames = (len(audio_array) - frame_length) // hop_length + 1
        frame_features = []
        for i in range(0, min(n_frames, 50)):  # Limit to 50 frames for performance
            start = i * hop_length
            end = start + frame_length
            if end > len(audio_array):
                break
            frame = audio_array[start:end]
            # Apply window
            window = np.hanning(len(frame))
            windowed_frame = frame * window
            # FFT
            fft = np.fft.fft(windowed_frame)
            magnitude = np.abs(fft)
            # Frame-level features
            frame_energy = np.sum(magnitude ** 2)
            frame_centroid = np.sum(np.arange(len(magnitude)) * magnitude) / np.sum(magnitude)
            frame_features.append([frame_energy, frame_centroid])
        if frame_features:
            frame_features = np.array(frame_features)
            # Aggregate features across frames
            features.extend([
                np.mean(frame_features[:, 0]),           # Mean energy
                np.std(frame_features[:, 0]),            # Energy std
                np.mean(frame_features[:, 1]),           # Mean centroid
                np.std(frame_features[:, 1]),            # Centroid std
                np.max(frame_features[:, 0]),            # Max energy
                np.min(frame_features[:, 0]),            # Min energy
                np.median(frame_features[:, 0]),         # Median energy
                np.var(frame_features[:, 0]),            # Energy variance
                np.mean(np.diff(frame_features[:, 0])),  # Energy delta
                np.std(np.diff(frame_features[:, 0]))    # Energy delta std
            ])
        else:
            features.extend(np.zeros(10))
        # Add some basic harmonic features
        fft_full = np.fft.fft(audio_array)
        magnitude_full = np.abs(fft_full)
        # Find the fundamental frequency (simple peak detection)
        freqs = np.fft.fftfreq(len(audio_array), 1 / sample_rate)
        pos_freqs = freqs[:len(freqs) // 2]
        pos_magnitude = magnitude_full[:len(magnitude_full) // 2]
        # Harmonic features
        peak_idx = np.argmax(pos_magnitude)
        fundamental_freq = pos_freqs[peak_idx]
        # Harmonic ratios (simple approximation)
        harmonic_features = []
        for harmonic in [2, 3, 4, 5]:
            target_freq = fundamental_freq * harmonic
            if target_freq < sample_rate / 2:
                target_idx = np.argmin(np.abs(pos_freqs - target_freq))
                harmonic_ratio = pos_magnitude[target_idx] / pos_magnitude[peak_idx]
                harmonic_features.append(harmonic_ratio)
            else:
                harmonic_features.append(0.0)
        features.extend(harmonic_features)
        # Add more frequency-based features
        features.extend([
            fundamental_freq / sample_rate,  # Normalized fundamental
            np.sum(pos_magnitude > np.mean(pos_magnitude)),  # Number of significant peaks
            np.sum(pos_magnitude) / len(pos_magnitude),  # Average magnitude
            np.std(pos_magnitude),  # Magnitude std
            np.max(pos_magnitude) / np.mean(pos_magnitude),  # Peak prominence
            np.sum(pos_magnitude[:len(pos_magnitude) // 4]),  # Low-frequency energy
            np.sum(pos_magnitude[len(pos_magnitude) // 4:len(pos_magnitude) // 2]),  # Mid-frequency energy
            np.sum(pos_magnitude[len(pos_magnitude) // 2:3 * len(pos_magnitude) // 4]),  # High-frequency energy
            np.sum(pos_magnitude[3 * len(pos_magnitude) // 4:]),  # Very-high-frequency energy
            np.corrcoef(pos_magnitude[:-1], pos_magnitude[1:])[0, 1] if len(pos_magnitude) > 1 else 0.0,  # Spectral autocorrelation
            np.sum(np.diff(pos_magnitude) > 0) / len(pos_magnitude)  # Spectral flux
        ])
        return features
    def _estimate_tempo_numpy(self, audio_array: np.ndarray, sample_rate: int) -> float:
        """Simple tempo estimation using numpy only."""
        try:
            # Simple envelope-based tempo detection
            hop_length = 512
            frame_length = 2048
            # Calculate envelope
            envelope = np.abs(audio_array)
            # Downsample the envelope
            n_frames = (len(envelope) - frame_length) // hop_length + 1
            envelope_frames = []
            for i in range(0, min(n_frames, 1000)):  # Limit frames
                start = i * hop_length
                end = start + frame_length
                if end > len(envelope):
                    break
                frame_energy = np.mean(envelope[start:end])
                envelope_frames.append(frame_energy)
            if len(envelope_frames) < 10:
                return 120.0
            envelope_frames = np.array(envelope_frames)
            # Find peaks in the envelope
            from scipy.signal import find_peaks
            peaks, _ = find_peaks(envelope_frames, height=np.mean(envelope_frames))
            if len(peaks) > 2:
                # Calculate intervals between peaks
                peak_intervals = np.diff(peaks) * hop_length / sample_rate
                # Keep only reasonable intervals (0.2 to 2 seconds)
                valid_intervals = peak_intervals[(peak_intervals > 0.2) & (peak_intervals < 2.0)]
                if len(valid_intervals) > 0:
                    avg_interval = np.mean(valid_intervals)
                    tempo = 60.0 / avg_interval
                    # Constrain to a reasonable range
                    tempo = max(60, min(200, tempo))
                    return tempo
            return 120.0
        except Exception as e:
            logger.warning(f"Numpy tempo estimation failed: {e}")
            return 120.0
    def extract_mert_features(self, audio_array: np.ndarray, sample_rate: int) -> np.ndarray:
        """Extract MERT features using the music understanding model."""
        try:
            # Load the MERT model on demand
            self._load_mert_model()
            # Resample to 24kHz for MERT processing if needed
            if sample_rate != 24000:
                logger.info(f"Resampling from {sample_rate}Hz to 24kHz for MERT...")
                from scipy.signal import resample
                target_length = int(len(audio_array) * 24000 / sample_rate)
                audio_array_24k = resample(audio_array, target_length)
            else:
                audio_array_24k = audio_array
            # Critical: convert to float32 for MERT (community checkpoints expect float32)
            audio_array_24k = audio_array_24k.astype(np.float32, copy=False)
            # Use a multi-window strategy like CLAP to capture variation
            total_duration = len(audio_array_24k) / 24000
            logger.info(f"Processing MERT features for {len(audio_array_24k)} samples at 24kHz ({total_duration:.1f} seconds)")
            # Multi-window sampling approach (5-second windows as per the MERT paper)
            hop_seconds = 5  # Move the window every 5 seconds
            win_seconds = 5  # Each window is 5 seconds (MERT positional encodings trained on 5s)
            hop_samples = hop_seconds * 24000
            win_samples = win_seconds * 24000
            samples = []
            if total_duration <= win_seconds:
                # Short audio: use the entire file
                logger.info("Short audio: using entire file for MERT")
                samples = [audio_array_24k]
            else:
                # Multi-window sampling: overlapping windows across the entire track
                logger.info("Multi-window sampling for MERT: extracting overlapping windows")
                max_offset = int(total_duration - win_seconds) + 1
                for t in range(0, max_offset, hop_seconds):
                    start_sample = t * 24000
                    end_sample = start_sample + win_samples
                    # Ensure we don't go beyond the audio length
                    if end_sample <= len(audio_array_24k):
                        window = audio_array_24k[start_sample:end_sample]
                        samples.append(window)
                    else:
                        # Last window: take what's available
                        window = audio_array_24k[start_sample:]
                        if len(window) >= win_samples // 2:  # At least 2.5 seconds
                            # Pad to the full window length
                            padded_window = np.pad(window, (0, win_samples - len(window)))
                            samples.append(padded_window)
                        break
            logger.info(f"Generated {len(samples)} MERT windows covering the entire track")
            # Process each sample and collect embeddings
            embeddings = []
            for i, sample in enumerate(samples):
                sample_length = win_samples
                if len(sample) < sample_length:
                    # Pad short samples with zeros
                    sample = np.pad(sample, (0, sample_length - len(sample)))
                logger.info(f"Processing MERT sample {i + 1}/{len(samples)}")
                # Process with the MERT processor
                inputs = self.mert_processor(
                    sample,
                    sampling_rate=24000,
                    return_tensors="pt"
                )
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                with torch.no_grad():
                    outputs = self.mert_model(**inputs, output_hidden_states=True)
                # Get all hidden states (13 layers for MERT-v1-95M)
                all_hidden_states = outputs.hidden_states  # Shape: [13, batch_size, time_steps, 768]
                # Time reduction: average across the time dimension
                # Shape: [13, batch_size, 768]
                time_reduced_states = torch.stack([layer.mean(dim=1) for layer in all_hidden_states])
                # Take the mean across all layers to get a single 768-dim representation.
                # This follows the typical approach for utterance-level representations.
                window_embedding = time_reduced_states.mean(dim=0).squeeze()  # Shape: [768]
                embeddings.append(window_embedding.cpu().numpy().flatten())
            # Average the raw embeddings from all samples (like CLAP)
            final_embedding = np.mean(embeddings, axis=0)
            # Ensure finite values
            mert_features = np.nan_to_num(final_embedding, nan=0.0, posinf=0.0, neginf=0.0)
            # L2 normalize
            norm = np.linalg.norm(mert_features)
            if norm > 0:
                mert_features = mert_features / norm
            else:
                # If the norm is 0, create a small non-zero vector
                mert_features = np.ones(768, dtype=np.float32) * 0.001
                mert_features = mert_features / np.linalg.norm(mert_features)
            # Final check for finite values
            mert_features = np.nan_to_num(mert_features, nan=0.0, posinf=0.0, neginf=0.0)
            logger.info(f"Extracted MERT features: {len(mert_features)} dimensions (norm: {np.linalg.norm(mert_features):.6f})")
            return mert_features
        except Exception as e:
            logger.error(f"Error extracting MERT features: {str(e)}")
            # Return a normalized near-zero vector if extraction fails
            zero_features = np.ones(768, dtype=np.float32) * 0.001
            return zero_features / np.linalg.norm(zero_features)
    def fuse_embeddings(self, clap_embedding: np.ndarray, pitch_features: np.ndarray = None, mert_features: np.ndarray = None):
        """Fuse CLAP, pitch-aware, and/or MERT features based on the fusion mode.

        Returns a single concatenated vector in VECTOR_CONCAT mode, or a dict of
        separate embeddings in SCORE_FUSION mode.
        """
        if FUSION_MODE == "VECTOR_CONCAT":
            # Normalize before and after concatenation (critical for proper similarity computation)
            embeddings_to_fuse = []
            # Step 1: Normalize each embedding individually (clap_u = normalize(clap))
            # Always include CLAP
            clap_u = clap_embedding / np.linalg.norm(clap_embedding)
            embeddings_to_fuse.append(clap_u)
            # Add MERT if available
            if mert_features is not None:
                mert_u = mert_features / np.linalg.norm(mert_features)
                embeddings_to_fuse.append(mert_u)
            # Add pitch features if available
            if pitch_features is not None:
                pitch_u = pitch_features / np.linalg.norm(pitch_features)
                embeddings_to_fuse.append(pitch_u)
            # Step 2: Concatenate the normalized embeddings (fused = cat([clap_u, mert_u]))
            fused = np.concatenate(embeddings_to_fuse)
            # Step 3: Normalize the concatenated result (fused = normalize(fused)).
            # Critical: without this, similarity is inflated by √2 (0.83 → 0.93)
            fused_norm = np.linalg.norm(fused)
            if fused_norm > 0:
                fused = fused / fused_norm
            else:
                logger.warning("Zero norm in fused embedding - creating fallback vector")
                fused = np.ones_like(fused) * 0.001
                fused = fused / np.linalg.norm(fused)
            # Verify the norm is 1.0 (assert abs(fused.norm() - 1.0) < 1e-6)
            final_norm = np.linalg.norm(fused)
            if abs(final_norm - 1.0) > 1e-6:
                logger.warning(f"Fused embedding norm is {final_norm:.8f}, not 1.0 - normalization issue!")
            # Ensure finite values for JSON serialization
            fused = np.nan_to_num(fused, nan=0.0, posinf=0.0, neginf=0.0)
            return fused
        else:
            # SCORE_FUSION: return the embeddings separately for weighted similarity calculation
            result = {}
            result["clap"] = np.nan_to_num(clap_embedding, nan=0.0, posinf=0.0, neginf=0.0)
            if mert_features is not None:
                result["mert"] = np.nan_to_num(mert_features, nan=0.0, posinf=0.0, neginf=0.0)
            if pitch_features is not None:
                result["pitch"] = np.nan_to_num(pitch_features, nan=0.0, posinf=0.0, neginf=0.0)
            return result
    def is_supported_format(self, image_url: str) -> bool:
        """Check if the image format is supported by PIL/CLIP."""
        unsupported_extensions = ['.avif', '.heic', '.heif']
        url_lower = image_url.lower()
        return not any(url_lower.endswith(ext) for ext in unsupported_extensions)

    def detect_image_format(self, content: bytes) -> str:
        """Detect the actual image format from the file content."""
        try:
            # Check for AVIF signature
            if content.startswith(b'\x00\x00\x00') and b'ftypavif' in content[:32]:
                return 'AVIF'
            # Check for HEIC signature
            elif content.startswith(b'\x00\x00\x00') and b'ftyp' in content[:32] and (b'heic' in content[:32] or b'heix' in content[:32]):
                return 'HEIC'
            # Check for WebP
            elif content.startswith(b'RIFF') and b'WEBP' in content[:12]:
                return 'WebP'
            # Check for PNG
            elif content.startswith(b'\x89PNG\r\n\x1a\n'):
                return 'PNG'
            # Check for JPEG
            elif content.startswith(b'\xff\xd8\xff'):
                return 'JPEG'
            # Check for GIF
            elif content.startswith((b'GIF87a', b'GIF89a')):
                return 'GIF'
            else:
                return 'Unknown'
        except Exception:
            return 'Unknown'
    def encode_image(self, image_url: str) -> list:
        try:
            logger.info(f"Processing image: {image_url}")
            # Quick URL-based format check first
            if not self.is_supported_format(image_url):
                logger.warning(f"Unsupported format detected from URL: {image_url}")
                raise HTTPException(status_code=422, detail="Unsupported image format (AVIF/HEIC not supported)")
            response = requests.get(image_url, timeout=30, headers={'User-Agent': 'CLIP-Service/1.0'})
            response.raise_for_status()
            # Detect the actual format from the content
            image_format = self.detect_image_format(response.content)
            logger.info(f"Detected image format: {image_format}")
            if image_format in ['AVIF', 'HEIC']:
                logger.warning(f"Unsupported format detected: {image_format} for {image_url}")
                raise HTTPException(status_code=422, detail=f"Unsupported image format: {image_format}")
            try:
                image = Image.open(io.BytesIO(response.content))
            except Exception as e:
                logger.error(f"PIL cannot open image {image_url}: {str(e)}")
                if "cannot identify image file" in str(e).lower():
                    raise HTTPException(status_code=422, detail="Unsupported or corrupted image format")
                raise
            if image.mode != 'RGB':
                logger.info(f"Converting image from {image.mode} to RGB")
                image = image.convert('RGB')
            # Resize the image if it is too large to avoid memory issues
            max_size = 224  # CLIP's expected input size
            if max(image.size) > max_size:
                image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
            # Use the working method directly (Method 3) to avoid fallback overhead
            inputs = self.clip_processor(
                images=[image],
                return_tensors="pt"
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                image_features = self.clip_model.get_image_features(**inputs)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            # Ensure safe values for JSON serialization
            embedding = image_features.cpu().numpy().flatten()
            embedding = np.nan_to_num(embedding, nan=0.0, posinf=0.0, neginf=0.0)
            return embedding.tolist()
        except HTTPException:
            # Preserve the intended 4xx responses (e.g. unsupported formats) instead of converting them to 500s
            raise
        except Exception as e:
            logger.error(f"Error encoding image {image_url}: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to encode image: {str(e)}")
    def encode_text(self, text: str) -> list:
        try:
            logger.info(f"Processing text: {text[:50]}...")
            inputs = self.clip_processor(text=[text], return_tensors="pt", padding=True).to(self.device)
            with torch.no_grad():
                text_features = self.clip_model.get_text_features(**inputs)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            # Ensure safe values for JSON serialization
            embedding = text_features.cpu().numpy().flatten()
            embedding = np.nan_to_num(embedding, nan=0.0, posinf=0.0, neginf=0.0)
            return embedding.tolist()
        except Exception as e:
            logger.error(f"Error encoding text '{text[:50]}...': {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to encode text: {str(e)}")
    def encode_audio(self, audio_url: str) -> list:
        try:
            logger.info(f"Processing audio: {audio_url}")
            # Load the CLAP model on demand
            self._load_clap_model()
            # Pitch fusion is enabled if pitch-aware features are available
            if ENABLE_PITCH_FUSION and PITCH_AWARE_AVAILABLE:
                logger.info("Pitch fusion enabled with librosa features")
            # Download the audio file
            response = requests.get(audio_url, timeout=60, headers={'User-Agent': 'CLAP-Service/1.0'})
            response.raise_for_status()
            # Save to a temporary file
            with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file:
                tmp_file.write(response.content)
                tmp_path = tmp_file.name
            try:
                # Load audio using soundfile first, then resample with scipy if needed.
                # This avoids the caching issues with librosa.load.
                logger.info("Loading audio with soundfile...")
                audio_array, original_sr = sf.read(tmp_path)
                # Convert to mono if needed
                if len(audio_array.shape) > 1:
                    logger.info("Converting stereo to mono")
                    audio_array = audio_array.mean(axis=1)
                else:
                    logger.info("Audio is already mono")
                # Resample to 48kHz for CLAP processing
                if original_sr != 48000:
                    logger.info(f"Resampling from {original_sr}Hz to 48kHz...")
                    # Use scipy for resampling to avoid librosa caching issues
                    from scipy.signal import resample
                    target_length = int(len(audio_array) * 48000 / original_sr)
                    audio_array_48k = resample(audio_array, target_length)
                else:
                    audio_array_48k = audio_array
                # Critical: convert to float32 for CLAP (community checkpoints expect float32)
                audio_array_48k = audio_array_48k.astype(np.float32, copy=False)
                total_duration = len(audio_array_48k) / 48000
                logger.info(f"Audio loaded: {len(audio_array_48k)} samples at 48kHz ({total_duration:.1f} seconds)")
                # Check the audio duration limit and truncate both arrays consistently
                if total_duration > MAX_AUDIO_DURATION_SEC:
                    logger.warning(f"Audio duration {total_duration:.1f}s exceeds limit {MAX_AUDIO_DURATION_SEC}s, truncating...")
                    max_samples_48k = MAX_AUDIO_DURATION_SEC * 48000
                    max_samples_orig = MAX_AUDIO_DURATION_SEC * original_sr
                    # Truncate both arrays to keep MERT and pitch processing within limits
                    audio_array_48k = audio_array_48k[:max_samples_48k]
                    audio_array = audio_array[:max_samples_orig]
                    total_duration = MAX_AUDIO_DURATION_SEC
                    logger.info(f"Truncated both arrays to {total_duration:.1f} seconds (48kHz: {len(audio_array_48k)} samples, {original_sr}Hz: {len(audio_array)} samples)")
                # Process with CLAP (10s windows, 5s hops)
                clap_embedding = self._process_clap_embeddings(audio_array_48k, total_duration)
                # Initialize additional features
                pitch_features = None
                mert_features = None
                # Process with MERT features if enabled
                if ENABLE_MERT_FUSION and MERT_AVAILABLE:
                    try:
                        logger.info("Processing with MERT features for fusion...")
                        mert_features = self.extract_mert_features(audio_array, original_sr)
                    except Exception as e:
                        logger.error(f"MERT feature processing failed: {str(e)}")
                        mert_features = None
                # Process with pitch-aware features if enabled (can run alongside MERT)
                if ENABLE_PITCH_FUSION and PITCH_AWARE_AVAILABLE:
                    try:
                        logger.info("Processing with pitch-aware features for fusion...")
                        pitch_features = self._process_pitch_features(audio_array, original_sr, total_duration)
                    except Exception as e:
                        logger.error(f"Pitch feature processing failed: {str(e)}")
                        pitch_features = None
                # Handle fusion based on which features are available
                if mert_features is not None or pitch_features is not None:
                    if FUSION_MODE == "VECTOR_CONCAT":
                        final_embedding = self.fuse_embeddings(clap_embedding, pitch_features, mert_features)
                        if mert_features is not None:
                            logger.info(f"Fused embedding dimensions: {len(final_embedding)} (CLAP 512 + MERT 768)")
                        elif pitch_features is not None:
                            logger.info(f"Fused embedding dimensions: {len(final_embedding)} (CLAP 512 + Pitch 85)")
                        return final_embedding.tolist()
                    else:
                        # SCORE_FUSION: return the embeddings separately
                        logger.info("Score fusion mode: returning separate embeddings")
                        result = self.fuse_embeddings(clap_embedding, pitch_features, mert_features)
                        result["fusion_mode"] = "SCORE_FUSION"
                        result["fusion_alpha"] = FUSION_ALPHA
                        # Convert to lists for JSON serialization
                        for key in result:
                            if isinstance(result[key], np.ndarray):
                                result[key] = result[key].tolist()
                        return result
                else:
                    # CLAP-only processing (backwards compatible)
                    logger.info("Using CLAP-only processing")
                    return clap_embedding.tolist()
            finally:
                # Clean up the temp file
                if os.path.exists(tmp_path):
                    os.unlink(tmp_path)
        except Exception as e:
            logger.error(f"Error encoding audio {audio_url}: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to encode audio: {str(e)}")
    def _process_clap_embeddings(self, audio_array: np.ndarray, total_duration: float) -> np.ndarray:
        """Process audio with CLAP using 10s windows and 5s hops."""
        # Multi-window sampling approach for better discrimination
        hop_seconds = 5   # Move the window every 5 seconds
        win_seconds = 10  # Each window is 10 seconds
        hop_samples = hop_seconds * 48000
        win_samples = win_seconds * 48000
        samples = []
        if total_duration <= win_seconds:
            # Short audio: use the entire file
            logger.info("Short audio: using entire file for CLAP")
            samples = [audio_array]
        else:
            # Multi-window sampling: overlapping windows across the entire track
            logger.info("Multi-window sampling for CLAP: extracting overlapping windows")
            max_offset = int(total_duration - win_seconds) + 1
            for t in range(0, max_offset, hop_seconds):
                start_sample = t * 48000
                end_sample = start_sample + win_samples
                # Ensure we don't go beyond the audio length
                if end_sample <= len(audio_array):
                    window = audio_array[start_sample:end_sample]
                    samples.append(window)
                else:
                    # Last window: take what's available
                    window = audio_array[start_sample:]
                    if len(window) >= win_samples // 2:  # At least 5 seconds
                        # Pad to the full window length
                        padded_window = np.pad(window, (0, win_samples - len(window)))
                        samples.append(padded_window)
                    break
        logger.info(f"Generated {len(samples)} CLAP windows covering the entire track")
        # Process each sample and collect embeddings
        embeddings = []
        for i, sample in enumerate(samples):
            sample_length = win_samples
            if len(sample) < sample_length:
                # Pad short samples with zeros
                sample = np.pad(sample, (0, sample_length - len(sample)))
            logger.info(f"Processing CLAP sample {i + 1}/{len(samples)}")
            # Process with CLAP
            inputs = self.clap_processor(
                audios=sample,
                sampling_rate=48000,
                return_tensors="pt"
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                audio_features = self.clap_model.get_audio_features(**inputs)
                window_vec = audio_features.squeeze(0)  # 512-D, not yet L2-normalized
            embeddings.append(window_vec.cpu().numpy().flatten())
        # Average the raw embeddings from all samples
        final_embedding = np.mean(embeddings, axis=0)
        # Ensure finite values before normalization
        final_embedding = np.nan_to_num(final_embedding, nan=0.0, posinf=0.0, neginf=0.0)
        # Ensure proper L2 normalization for cosine similarity
        norm = np.linalg.norm(final_embedding)
        if norm > 0:
            final_embedding = final_embedding / norm
        else:
            logger.warning("Zero norm CLAP embedding detected")
            # Create a small normalized vector instead of zeros
            final_embedding = np.ones_like(final_embedding) * 0.001
            final_embedding = final_embedding / np.linalg.norm(final_embedding)
        # Final safety check for JSON serialization
        final_embedding = np.nan_to_num(final_embedding, nan=0.0, posinf=0.0, neginf=0.0)
        # Verify normalization
        final_norm = np.linalg.norm(final_embedding)
        logger.info(f"Final CLAP embedding norm: {final_norm:.6f} (should be ~1.0)")
        return final_embedding
    def _process_pitch_features(self, audio_array: np.ndarray, original_sr: int, total_duration: float) -> np.ndarray:
        """Process audio with pitch-aware features using 5s windows and 2s hops."""
        # Pitch feature processing parameters
        hop_seconds = 2  # Move the window every 2 seconds
        win_seconds = 5  # Each window is 5 seconds
        hop_samples = hop_seconds * original_sr
        win_samples = win_seconds * original_sr
        samples = []
        if total_duration <= win_seconds:
            # Short audio: use the entire file
            logger.info("Short audio: using entire file for pitch features")
            samples = [audio_array]
        else:
            # Multi-window sampling: overlapping windows across the entire track
            logger.info("Multi-window sampling for pitch features: extracting overlapping windows")
            max_offset = int(total_duration - win_seconds) + 1
            for t in range(0, max_offset, hop_seconds):
                start_sample = t * original_sr
                end_sample = start_sample + win_samples
                # Ensure we don't go beyond the audio length
                if end_sample <= len(audio_array):
                    window = audio_array[start_sample:end_sample]
                    samples.append(window)
                else:
                    # Last window: take what's available
                    window = audio_array[start_sample:]
                    if len(window) >= win_samples // 2:  # At least 2.5 seconds
                        # Pad to the full window length
                        padded_window = np.pad(window, (0, win_samples - len(window)))
                        samples.append(padded_window)
                    break
        logger.info(f"Generated {len(samples)} pitch feature windows covering the entire track")
        # Process each sample and collect features
        feature_vectors = []
        for i, sample in enumerate(samples):
            sample_length = win_samples
            if len(sample) < sample_length:
                # Pad short samples with zeros
                sample = np.pad(sample, (0, sample_length - len(sample)))
            logger.info(f"Processing pitch features sample {i + 1}/{len(samples)}")
            # Extract pitch features
            pitch_features = self.extract_pitch_features(sample, original_sr)
            feature_vectors.append(pitch_features)
        # Average the raw feature vectors from all samples
        final_features = np.mean(feature_vectors, axis=0)
        # Ensure finite values before normalization
        final_features = np.nan_to_num(final_features, nan=0.0, posinf=0.0, neginf=0.0)
        # Ensure proper L2 normalization for cosine similarity
        norm = np.linalg.norm(final_features)
        if norm > 0:
            final_features = final_features / norm
        else:
            logger.warning("Zero norm pitch features detected")
            # Create a small normalized vector instead of zeros
            final_features = np.ones_like(final_features) * 0.001
            final_features = final_features / np.linalg.norm(final_features)
        # Final safety check for JSON serialization
        final_features = np.nan_to_num(final_features, nan=0.0, posinf=0.0, neginf=0.0)
        # Verify normalization
        final_norm = np.linalg.norm(final_features)
        logger.info(f"Final pitch features norm: {final_norm:.6f} (should be ~1.0)")
        return final_features
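

# In SCORE_FUSION mode fuse_embeddings returns the per-model embeddings separately and the
# weighted similarity is computed by whatever service consumes this API. The helper below is
# a minimal sketch of that consumer-side computation, assuming FUSION_ALPHA weights the CLAP
# score; it is illustrative only and is not called by this service's request handlers.
def example_score_fusion_similarity(query: dict, candidate: dict, alpha: float = FUSION_ALPHA) -> float:
    """Weighted cosine similarity between two SCORE_FUSION results (illustrative sketch)."""
    def cosine(a, b):
        a = np.asarray(a, dtype=np.float32)
        b = np.asarray(b, dtype=np.float32)
        denom = np.linalg.norm(a) * np.linalg.norm(b)
        return float(np.dot(a, b) / denom) if denom > 0 else 0.0

    clap_score = cosine(query["clap"], candidate["clap"])
    # Prefer MERT if both sides carry it, otherwise fall back to pitch features, else CLAP only.
    for secondary in ("mert", "pitch"):
        if secondary in query and secondary in candidate:
            return alpha * clap_score + (1.0 - alpha) * cosine(query[secondary], candidate[secondary])
    return clap_score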


# Initialize the service with error handling
logger.info("Initializing CLIP service...")
try:
    clip_service = CLIPService()
    logger.info("CLIP service initialized successfully!")
except Exception as e:
    logger.error(f"Failed to initialize CLIP service: {str(e)}")
    logger.error(f"Error details: {type(e).__name__}: {str(e)}")
    # For now, let the app start; service calls will fail gracefully
    clip_service = None


class ImageRequest(BaseModel):
    image_url: str


class TextRequest(BaseModel):
    text: str


class AudioRequest(BaseModel):
    audio_url: str

@app.get("/")
async def root():
    return {
        "message": "CLIP Service API",
        "version": "1.0.0",
        "model": "clip-vit-large-patch14",
        "endpoints": ["/encode/image", "/encode/text", "/encode/audio", "/health"],
        "status": "ready" if clip_service else "error"
    }

@app.post("/encode/image")
async def encode_image(request: ImageRequest):
    if not clip_service:
        raise HTTPException(status_code=503, detail="CLIP service not available")
    embedding = clip_service.encode_image(request.image_url)
    safe_embedding = sanitize_for_json(embedding)
    return {"embedding": safe_embedding, "dimensions": len(safe_embedding)}

@app.post("/encode/text")
async def encode_text(request: TextRequest):
    if not clip_service:
        raise HTTPException(status_code=503, detail="CLIP service not available")
    embedding = clip_service.encode_text(request.text)
    safe_embedding = sanitize_for_json(embedding)
    return {"embedding": safe_embedding, "dimensions": len(safe_embedding)}

@app.post("/encode/audio")
async def encode_audio(request: AudioRequest):
    if not clip_service:
        raise HTTPException(status_code=503, detail="CLAP service not available")
    if not CLAP_AVAILABLE:
        raise HTTPException(status_code=501, detail="CLAP model not available in this transformers version")
    embedding = clip_service.encode_audio(request.audio_url)
    # Handle both single-embedding and fusion-mode results
    if isinstance(embedding, dict):
        # Score fusion mode: sanitize all embeddings
        safe_embedding = {}
        dimensions = {}
        for key, value in embedding.items():
            if key in ["clap", "pitch", "mert"] and isinstance(value, list):
                safe_embedding[key] = sanitize_for_json(value)
                dimensions[key] = len(safe_embedding[key])
            else:
                safe_embedding[key] = value
        return {"embedding": safe_embedding, "dimensions": dimensions}
    else:
        # Single embedding (CLAP-only or concatenated)
        safe_embedding = sanitize_for_json(embedding)
        return {"embedding": safe_embedding, "dimensions": len(safe_embedding)}

@app.get("/health")
async def health_check():
    if not clip_service:
        return {
            "status": "unhealthy",
            "model": "clip-vit-large-patch14",
            "error": "Service failed to initialize"
        }
    health_info = {
        "status": "healthy",
        "models": {
            "clip": "clip-vit-large-patch14",
            "clap": f"clap-htsat-unfused (lazy loaded, method: {CLAP_METHOD})" if CLAP_AVAILABLE else "not available",
            "mert": f"MERT-v1-95M (lazy loaded, method: {MERT_METHOD})" if MERT_AVAILABLE else "not available"
        },
        "device": clip_service.device,
        "service": "ready",
        "cache_dir": cache_dir
    }
    # Add pitch-aware information
    if PITCH_AWARE_AVAILABLE:
        health_info["models"]["pitch_aware"] = f"librosa features ({PITCH_METHOD})"
    # Add fusion information
    fusion_enabled = ENABLE_PITCH_FUSION or ENABLE_MERT_FUSION
    if fusion_enabled:
        health_info["fusion"] = {
            "enabled": True,
            "mode": FUSION_MODE,
            "pitch_fusion_enabled": ENABLE_PITCH_FUSION,
            "mert_fusion_enabled": ENABLE_MERT_FUSION,
            "pitch_aware_available": PITCH_AWARE_AVAILABLE,
            "mert_available": MERT_AVAILABLE
        }
        if FUSION_MODE == "SCORE_FUSION":
            health_info["fusion"]["alpha"] = FUSION_ALPHA
    else:
        health_info["fusion"] = {
            "enabled": False,
            "mode": "CLAP_ONLY"
        }
    return health_info

if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", 7860))  # Hugging Face Spaces uses port 7860
    uvicorn.run(app, host="0.0.0.0", port=port)
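
# Example client call (illustrative; adjust the host/port to your deployment):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/encode/text",
#       json={"text": "a jazz piano trio"},
#   )
#   print(resp.json()["dimensions"])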