|
import os

# Redirect cache directories to writable /tmp paths (useful when the app runs
# in a container with a read-only home). These must be set before importing
# transformers, which reads them at import time.
os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
os.environ['HF_HOME'] = '/tmp/huggingface'
os.environ['NUMBA_CACHE_DIR'] = '/tmp/numba_cache'
os.environ['NUMBA_DISABLE_JIT'] = '1'

import logging
import subprocess
import tempfile
from contextlib import asynccontextmanager
from datetime import datetime

import numpy as np
import uvicorn
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from transformers import pipeline

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Global pipeline instance, created at startup by the lifespan handler.
classifier = None

# Checkpoint loaded at startup and reported by the info endpoints.
MODEL_NAME = "alonb19/EF-instruments-v1"
|
|
|
async def load_model():
    """Load the instrument-detection model with service-specific settings."""
    global classifier
    try:
        logger.info("Loading model...")

        os.makedirs('/tmp/transformers_cache', exist_ok=True)
        os.makedirs('/tmp/huggingface', exist_ok=True)
        os.makedirs('/tmp/numba_cache', exist_ok=True)

        logger.info(f"Model: {MODEL_NAME}")

        # device=-1 forces CPU inference. Padding/truncation are handled by
        # the checkpoint's feature extractor; the audio-classification
        # pipeline has no feature_extractor_kwargs or return_all_scores
        # options (it silently ignored them), so they are dropped here. With
        # only three labels, the default top_k already returns every score.
        classifier = pipeline(
            "audio-classification",
            model=MODEL_NAME,
            device=-1
        )

        logger.info("✅ Model loaded successfully")
    except Exception as e:
        logger.error(f"❌ Error loading model: {e}")
        classifier = None
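# Illustrative classifier output for a three-label checkpoint (scores made up):
#   classifier(audio_data)
#   -> [{'label': 'Guitar', 'score': 0.91},
#       {'label': 'Piano', 'score': 0.06},
#       {'label': 'Drum', 'score': 0.03}]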
|
|
|
async def cleanup_model():
    """Release model resources."""
    global classifier
    classifier = None
    logger.info("Resources released")
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    await load_model()
    yield
    await cleanup_model()
|
|
|
app = FastAPI(
    title="Musical Instrument Detection API",
    description="API for detecting musical instruments in short audio clips",
    version="7.0.0",
    lifespan=lifespan
)
|
|
|
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
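# NOTE: browsers will not accept Access-Control-Allow-Origin "*" on requests
# that carry credentials; list explicit origins if cookie/auth clients matter.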
|
|
|
def convert_audio_with_ffmpeg(input_path, output_path):
    """Convert any audio input to 16 kHz mono WAV using ffmpeg."""
    try:
        cmd = [
            'ffmpeg', '-y', '-i', input_path,
            '-ar', '16000', '-ac', '1',
            '-f', 'wav', output_path
        ]
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            timeout=30
        )
        return result.returncode == 0
    except Exception as e:
        logger.error(f"ffmpeg error: {e}")
        return False
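# Equivalent shell invocation (illustrative; 'input.webm' is a placeholder):
#   ffmpeg -y -i input.webm -ar 16000 -ac 1 -f wav output.wav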
|
|
|
def load_audio_robust(file_path):
    """Load audio robustly: normalize to 16 kHz mono WAV, then read it."""
    try:
        # The temp file always carries a '.wav' suffix, so this yields a
        # distinct sibling path for the converted copy.
        converted_path = file_path.replace('.wav', '_converted.wav')

        if convert_audio_with_ffmpeg(file_path, converted_path):
            try:
                import soundfile as sf  # lazy import: only needed here
                audio_data, sample_rate = sf.read(converted_path)

                if os.path.exists(converted_path):
                    os.unlink(converted_path)

                # Downmix multi-channel audio to mono.
                if len(audio_data.shape) > 1:
                    audio_data = np.mean(audio_data, axis=1)

                return audio_data.astype(np.float32), sample_rate
            except Exception:
                if os.path.exists(converted_path):
                    os.unlink(converted_path)
                raise

        raise Exception("Could not process the audio file")
    except Exception as e:
        logger.error(f"Error loading audio: {e}")
        raise
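# Minimal usage sketch (assumes '/tmp/clip.wav' exists and ffmpeg is installed):
#   data, sr = load_audio_robust('/tmp/clip.wav')  # float32 mono, sr == 16000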
|
|
|
def analyze_audio_segments(audio_data, sample_rate):
    """Analyze the audio in multiple segments for more reliable detection."""
    duration = len(audio_data) / sample_rate
    segment_duration = 8  # seconds per analysis window

    all_results = []

    if duration <= segment_duration:
        # Short clip: classify in a single pass. Raw numpy input is assumed
        # to match the feature extractor's rate (audio is converted to
        # 16 kHz upstream).
        results = classifier(audio_data)
        all_results.extend(results)
    else:
        # Longer clip: classify up to four full 8-second windows from the
        # start; audio past the last full window is ignored.
        num_segments = min(4, int(duration / segment_duration))

        for i in range(num_segments):
            start_sample = i * segment_duration * sample_rate
            end_sample = min((i + 1) * segment_duration * sample_rate, len(audio_data))
            segment = audio_data[start_sample:end_sample]

            # Peak-normalize each window independently.
            if np.max(np.abs(segment)) > 0:
                segment = segment / np.max(np.abs(segment))

            results = classifier(segment)
            all_results.extend(results)

            logger.info(f"📊 Segment {i+1}/{num_segments} analyzed")

    return all_results
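# Worked example: with the 15 s cap applied by /detect, a full-length clip
# yields min(4, int(15 / 8)) = 1 window covering seconds 0-8; the remaining
# seconds fall past the last full window and are not analyzed.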
|
|
|
def combine_instrument_results(all_results):
    """Combine per-segment results, keeping each label's best score."""
    instrument_scores = {}

    for result in all_results:
        label = result['label']
        score = result['score']

        if label in instrument_scores:
            instrument_scores[label] = max(instrument_scores[label], score)
        else:
            instrument_scores[label] = score

    # Keep labels whose best score clears the 5% threshold.
    instruments_detected = []
    for label, score in instrument_scores.items():
        if score > 0.05:
            instruments_detected.append({
                "label": label,
                "score": round(score, 4),
                "percentage": round(score * 100, 2)
            })

    instruments_detected.sort(key=lambda x: x['score'], reverse=True)

    return instruments_detected
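# Illustrative input/output (scores made up):
#   [{'label': 'Guitar', 'score': 0.9}, {'label': 'Guitar', 'score': 0.7},
#    {'label': 'Drum', 'score': 0.02}]
#   -> [{'label': 'Guitar', 'score': 0.9, 'percentage': 90.0}]
#   (Drum is dropped because 0.02 does not clear the 0.05 threshold.)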
|
|
|
@app.get("/") |
|
async def root(): |
|
return { |
|
"message": "Musical Instrument Detection API", |
|
"status": "online", |
|
"version": "7.0.0", |
|
"model": "Janiopi/detector_de_instrumentos_v1", |
|
"supported_instruments": ["Guitar", "Piano", "Drum"], |
|
"max_duration_seconds": 15, |
|
"endpoints": { |
|
"health": "/health", |
|
"detect": "/detect", |
|
"docs": "/docs" |
|
} |
|
} |
|
|
|
@app.get("/health") |
|
async def health_check(): |
|
"""Verificar estado del servicio - Igual que el código original""" |
|
return { |
|
"status": "online" if classifier is not None else "offline", |
|
"model_loaded": classifier is not None, |
|
"message": "API funcionando correctamente" if classifier is not None else "Modelo no disponible", |
|
"model_info": "Janiopi/detector_de_instrumentos_v1", |
|
"supported_instruments": ["Guitar", "Piano", "Drum"], |
|
"max_duration_seconds": 15, |
|
"sample_rate": 16000, |
|
"timestamp": datetime.now().isoformat() |
|
} |
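# Example check (illustrative):
#   curl -s http://localhost:7860/health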
|
|
|
@app.post("/detect") |
|
async def detect_instrument(audio: UploadFile = File(...)): |
|
""" |
|
Detectar instrumentos musicales - Manteniendo estructura del código original |
|
""" |
|
start_time = datetime.now() |
|
|
|
try: |
|
if classifier is None: |
|
raise HTTPException( |
|
status_code=503, |
|
detail="Modelo no disponible. Intenta más tarde." |
|
) |
|
|
|
logger.info(f"📁 Procesando: {audio.filename} ({audio.content_type})") |
|
|
|
|
|
content = await audio.read() |
|
logger.info(f"📏 Tamaño: {len(content)} bytes") |
|
|
|
if len(content) > 10 * 1024 * 1024: |
|
raise HTTPException(status_code=413, detail="Archivo muy grande") |
|
|
|
|
|
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_file: |
|
temp_file.write(content) |
|
temp_path = temp_file.name |
|
|
|
try: |
|
logger.info("🎵 Cargando audio...") |
|
|
|
|
|
audio_data, sample_rate = load_audio_robust(temp_path) |
|
|
|
logger.info(f"🔊 Audio cargado: {len(audio_data)} samples a {sample_rate}Hz") |
|
logger.info(f"⏱️ Duración: {len(audio_data)/sample_rate:.2f} segundos") |
|
|
|
|
|
if len(audio_data) < 1600: |
|
raise ValueError("Audio demasiado corto (mínimo 0.1 segundos)") |
|
|
|
|
|
max_samples = 15 * 16000 |
|
if len(audio_data) > max_samples: |
|
audio_data = audio_data[:max_samples] |
|
logger.info("🔄 Audio truncado a 15 segundos") |
|
|
|
|
|
audio_data = np.array(audio_data, dtype=np.float32) |
|
|
|
|
|
if np.max(np.abs(audio_data)) > 0: |
|
audio_data = audio_data / np.max(np.abs(audio_data)) |
|
|
|
logger.info("🤖 Ejecutando modelo...") |
|
|
|
|
|
all_results = analyze_audio_segments(audio_data, sample_rate) |
|
|
|
|
|
formatted_results = combine_instrument_results(all_results) |
|
|
|
logger.info(f"🎯 Resultados: {formatted_results}") |
|
|
|
processing_time = (datetime.now() - start_time).total_seconds() |
|
|
|
|
|
response = { |
|
"success": True, |
|
"results": formatted_results, |
|
"filename": audio.filename, |
|
"audio_info": { |
|
"samples": len(audio_data), |
|
"sample_rate": sample_rate, |
|
"duration_seconds": round(len(audio_data) / sample_rate, 2), |
|
"processed_size_bytes": len(content) |
|
}, |
|
"processing_time": round(processing_time, 3) |
|
} |
|
|
|
logger.info(f"✅ Completado en {processing_time:.3f}s") |
|
if formatted_results: |
|
logger.info(f"🎯 Principal: {formatted_results[0]['label']} ({formatted_results[0]['percentage']:.1f}%)") |
|
|
|
return response |
|
|
|
finally: |
|
|
|
if os.path.exists(temp_path): |
|
os.unlink(temp_path) |
|
|
|
except HTTPException: |
|
raise |
|
except Exception as e: |
|
logger.error(f"❌ Error inesperado: {e}") |
|
|
|
|
|
error_msg = str(e) |
|
if "Unable to create tensor" in error_msg: |
|
detail = "Error de formato de audio. Intenta con un archivo WAV de mejor calidad." |
|
elif "too short" in error_msg.lower(): |
|
detail = "Audio demasiado corto. Graba al menos 1 segundo." |
|
elif "padding" in error_msg: |
|
detail = "Error de procesamiento de audio. Intenta con un archivo diferente." |
|
else: |
|
detail = f"Error procesando audio: {error_msg}" |
|
|
|
raise HTTPException(status_code=500, detail=detail) |
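# Example request (illustrative; assumes the server is running locally):
#   curl -s -X POST -F "audio=@clip.wav" http://localhost:7860/detect
#
# Minimal Python client sketch ('requests' is an extra dependency):
#   import requests
#   with open('clip.wav', 'rb') as f:
#       r = requests.post('http://localhost:7860/detect', files={'audio': f})
#   print(r.json()['results'])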
|
|
|
@app.get("/test") |
|
async def test_endpoint(): |
|
"""Endpoint de prueba - igual que código original""" |
|
return { |
|
"message": "API funcionando", |
|
"timestamp": datetime.now().isoformat(), |
|
"test": "ok" |
|
} |
|
|
|
if __name__ == "__main__": |
|
logger.info("🚀 Iniciando Musical Instrument Detection API con modelo Janiopi...") |
|
uvicorn.run( |
|
app, |
|
host="0.0.0.0", |
|
port=7860, |
|
log_level="info" |
|
) |
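# Alternative launch (illustrative, assuming this file is saved as app.py):
#   uvicorn app:app --host 0.0.0.0 --port 7860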