from fastapi import FastAPI, File, UploadFile, HTTPException from fastapi.middleware.cors import CORSMiddleware from transformers import pipeline import tempfile import os import uvicorn import numpy as np import logging from datetime import datetime import torch from contextlib import asynccontextmanager import subprocess # Configurar cache os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache' os.environ['HF_HOME'] = '/tmp/huggingface' os.environ['NUMBA_CACHE_DIR'] = '/tmp/numba_cache' os.environ['NUMBA_DISABLE_JIT'] = '1' # Configurar logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) logger = logging.getLogger(__name__) # Variables globales classifier = None async def load_model(): """Cargar modelo Janiopi con configuración específica""" global classifier try: logger.info("Cargando modelo...") # Crear directorios de cache os.makedirs('/tmp/transformers_cache', exist_ok=True) os.makedirs('/tmp/huggingface', exist_ok=True) os.makedirs('/tmp/numba_cache', exist_ok=True) # MODELO JANIOPI con configuración específica model_name = "alonb19/EF-instruments-v1" logger.info(f"Modelo: {model_name}") # Configurar pipeline con padding y truncación como en el código original classifier = pipeline( "audio-classification", model=model_name, device=-1, feature_extractor_kwargs={ "padding": True, "truncation": True, "max_length": 240000, # 15 segundos a 16kHz "return_tensors": "pt" }, return_all_scores=True ) logger.info("✅ Modelo cargado exitosamente") except Exception as e: logger.error(f"❌ Error cargando modelo: {e}") classifier = None async def cleanup_model(): """Limpiar recursos""" global classifier classifier = None logger.info("Recursos liberados") @asynccontextmanager async def lifespan(app: FastAPI): await load_model() yield await cleanup_model() app = FastAPI( title="Musical Instrument Detection API", description="API para detectar instrumentos musicales con modelo Janiopi", version="7.0.0", lifespan=lifespan ) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) def convert_audio_with_ffmpeg(input_path, output_path): """Convertir audio usando ffmpeg""" try: cmd = [ 'ffmpeg', '-y', '-i', input_path, '-ar', '16000', '-ac', '1', '-f', 'wav', output_path ] result = subprocess.run( cmd, capture_output=True, text=True, timeout=30 ) return result.returncode == 0 except Exception as e: logger.error(f"Error con ffmpeg: {e}") return False def load_audio_robust(file_path): """Cargar audio de forma robusta""" try: converted_path = file_path.replace('.wav', '_converted.wav') if convert_audio_with_ffmpeg(file_path, converted_path): try: import soundfile as sf audio_data, sample_rate = sf.read(converted_path) if os.path.exists(converted_path): os.unlink(converted_path) if len(audio_data.shape) > 1: audio_data = np.mean(audio_data, axis=1) return audio_data.astype(np.float32), sample_rate except Exception as e: if os.path.exists(converted_path): os.unlink(converted_path) raise e raise Exception("No se pudo procesar el audio") except Exception as e: logger.error(f"Error cargando audio: {e}") raise def analyze_audio_segments(audio_data, sample_rate): """Analizar audio en múltiples segmentos para mejor detección""" duration = len(audio_data) / sample_rate segment_duration = 8 # 8 segundos por segmento all_results = [] if duration <= segment_duration: # Audio corto, analizar completo results = classifier(audio_data) all_results.extend(results) else: # Audio largo, analizar múltiples segmentos num_segments = min(4, int(duration / segment_duration)) # Máximo 4 segmentos for i in range(num_segments): start_sample = i * segment_duration * sample_rate end_sample = min((i + 1) * segment_duration * sample_rate, len(audio_data)) segment = audio_data[start_sample:end_sample] # Normalizar segmento if np.max(np.abs(segment)) > 0: segment = segment / np.max(np.abs(segment)) results = classifier(segment) all_results.extend(results) logger.info(f"📊 Segmento {i+1}/{num_segments} analizado") return all_results def combine_instrument_results(all_results): """Combinar resultados de múltiples segmentos""" # Agrupar por instrumento y tomar la máxima confianza instrument_scores = {} for result in all_results: label = result['label'] score = result['score'] if label in instrument_scores: instrument_scores[label] = max(instrument_scores[label], score) else: instrument_scores[label] = score # Convertir a lista instruments_detected = [] for label, score in instrument_scores.items(): if score > 0.05: # Umbral mínimo 5% instruments_detected.append({ "label": label, "score": round(score, 4), "percentage": round(score * 100, 2) }) # Ordenar por confianza descendente (igual que el código original) instruments_detected.sort(key=lambda x: x['score'], reverse=True) return instruments_detected @app.get("/") async def root(): return { "message": "Musical Instrument Detection API", "status": "online", "version": "7.0.0", "model": "Janiopi/detector_de_instrumentos_v1", "supported_instruments": ["Guitar", "Piano", "Drum"], "max_duration_seconds": 15, "endpoints": { "health": "/health", "detect": "/detect", "docs": "/docs" } } @app.get("/health") async def health_check(): """Verificar estado del servicio - Igual que el código original""" return { "status": "online" if classifier is not None else "offline", "model_loaded": classifier is not None, "message": "API funcionando correctamente" if classifier is not None else "Modelo no disponible", "model_info": "Janiopi/detector_de_instrumentos_v1", "supported_instruments": ["Guitar", "Piano", "Drum"], "max_duration_seconds": 15, "sample_rate": 16000, "timestamp": datetime.now().isoformat() } @app.post("/detect") async def detect_instrument(audio: UploadFile = File(...)): """ Detectar instrumentos musicales - Manteniendo estructura del código original """ start_time = datetime.now() try: if classifier is None: raise HTTPException( status_code=503, detail="Modelo no disponible. Intenta más tarde." ) logger.info(f"📁 Procesando: {audio.filename} ({audio.content_type})") # Leer contenido content = await audio.read() logger.info(f"📏 Tamaño: {len(content)} bytes") if len(content) > 10 * 1024 * 1024: # 10MB máximo raise HTTPException(status_code=413, detail="Archivo muy grande") # Crear archivo temporal with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_file: temp_file.write(content) temp_path = temp_file.name try: logger.info("🎵 Cargando audio...") # Cargar audio audio_data, sample_rate = load_audio_robust(temp_path) logger.info(f"🔊 Audio cargado: {len(audio_data)} samples a {sample_rate}Hz") logger.info(f"⏱️ Duración: {len(audio_data)/sample_rate:.2f} segundos") # Verificar duración mínima - igual que código original if len(audio_data) < 1600: # Menos de 0.1 segundos raise ValueError("Audio demasiado corto (mínimo 0.1 segundos)") # Truncar a máximo 15 segundos - igual que código original max_samples = 15 * 16000 if len(audio_data) > max_samples: audio_data = audio_data[:max_samples] logger.info("🔄 Audio truncado a 15 segundos") # Asegurar formato correcto audio_data = np.array(audio_data, dtype=np.float32) # Normalizar if np.max(np.abs(audio_data)) > 0: audio_data = audio_data / np.max(np.abs(audio_data)) logger.info("🤖 Ejecutando modelo...") # Analizar (usando segmentos para mejor precisión) all_results = analyze_audio_segments(audio_data, sample_rate) # Combinar resultados formatted_results = combine_instrument_results(all_results) logger.info(f"🎯 Resultados: {formatted_results}") processing_time = (datetime.now() - start_time).total_seconds() # Respuesta con MISMA ESTRUCTURA que el código original response = { "success": True, "results": formatted_results, # Manteniendo nombre "results" "filename": audio.filename, "audio_info": { "samples": len(audio_data), "sample_rate": sample_rate, "duration_seconds": round(len(audio_data) / sample_rate, 2), "processed_size_bytes": len(content) }, "processing_time": round(processing_time, 3) } logger.info(f"✅ Completado en {processing_time:.3f}s") if formatted_results: logger.info(f"🎯 Principal: {formatted_results[0]['label']} ({formatted_results[0]['percentage']:.1f}%)") return response finally: # Limpiar archivo temporal if os.path.exists(temp_path): os.unlink(temp_path) except HTTPException: raise except Exception as e: logger.error(f"❌ Error inesperado: {e}") # Mensajes de error específicos - igual que código original error_msg = str(e) if "Unable to create tensor" in error_msg: detail = "Error de formato de audio. Intenta con un archivo WAV de mejor calidad." elif "too short" in error_msg.lower(): detail = "Audio demasiado corto. Graba al menos 1 segundo." elif "padding" in error_msg: detail = "Error de procesamiento de audio. Intenta con un archivo diferente." else: detail = f"Error procesando audio: {error_msg}" raise HTTPException(status_code=500, detail=detail) @app.get("/test") async def test_endpoint(): """Endpoint de prueba - igual que código original""" return { "message": "API funcionando", "timestamp": datetime.now().isoformat(), "test": "ok" } if __name__ == "__main__": logger.info("🚀 Iniciando Musical Instrument Detection API con modelo Janiopi...") uvicorn.run( app, host="0.0.0.0", port=7860, log_level="info" )