Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| """ | |
| API server for Synesthesia runtime. | |
| Provides REST and WebSocket endpoints for controlling the runtime and accessing metrics. | |
| """ | |
| import logging | |
| import asyncio | |
| import json | |
| import os | |
| import sys | |
| from typing import Dict, Any, Optional | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import uvicorn | |
| # Add the project root to the sys.path so we can import the runtime module | |
| sys.path.append(os.path.join(os.path.dirname(__file__), '..')) | |
| from ML_Pipeline.shared.env import apply_defaults | |
# Import runtime components (we'll import them as they become available).
# If any import fails, every component class is replaced with None so the
# rest of the module can still load; callers must check for None before
# instantiating.
try:
    from runtime.camera_capture import CameraCapture
    from runtime.mic_capture import MicCapture
    from runtime.gemma_prompt_engine import GemmaPromptEngine
    from runtime.magenta_generation import MagentaGeneration
    from runtime.clip_scheduler import ClipScheduler
except ImportError as e:
    logging.warning(f"Some runtime modules not available: {e}")
    # We'll create placeholder classes for now
    CameraCapture = MicCapture = GemmaPromptEngine = MagentaGeneration = ClipScheduler = None
logger = logging.getLogger(__name__)

app = FastAPI(title="Synesthesia Runtime API", version="0.1.0")

# Apply canonical defaults from environment/dotenv
env_config = apply_defaults()

# CORS configuration: ALLOWED_ORIGINS is a comma-separated list (defaulting
# to the local dev server on port 1420); ALLOW_CREDENTIALS is parsed
# case-insensitively and defaults to enabled.
allowed_origins = [o.strip() for o in env_config.get("ALLOWED_ORIGINS", "http://localhost:1420").split(",")]
allow_credentials = env_config.get("ALLOW_CREDENTIALS", "True").lower() == "true"

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    allow_credentials=allow_credentials,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global runtime components — populated by startup_event, None until then.
# NOTE(review): when the runtime imports failed, these class names are None
# and the annotations evaluate as Optional[None]; harmless at runtime but
# worth confirming against the project's type checker.
camera: Optional[CameraCapture] = None
mic: Optional[MicCapture] = None
prompt_engine: Optional[GemmaPromptEngine] = None
music_generator: Optional[MagentaGeneration] = None
clip_scheduler: Optional[ClipScheduler] = None
# WebSocket connections for real-time updates
class ConnectionManager:
    """Tracks accepted WebSocket connections and fans messages out to them."""

    def __init__(self) -> None:
        # Connections that have completed the accept handshake.
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket) -> None:
        """Accept the handshake and start tracking the connection."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket) -> None:
        """Stop tracking a connection.

        Tolerates connections already removed (e.g. dropped by broadcast),
        where the original raised ValueError from list.remove.
        """
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

    async def send_personal_message(self, message: str, websocket: WebSocket) -> None:
        """Send *message* to a single connection."""
        await websocket.send_text(message)

    async def broadcast(self, message: str) -> None:
        """Send *message* to every connection, dropping any that fail.

        Iterates over a snapshot of the list: the original removed entries
        from the list it was iterating, which skips the element after each
        removal. Catches Exception instead of a bare except so that
        KeyboardInterrupt/SystemExit still propagate.
        """
        for connection in list(self.active_connections):
            try:
                await connection.send_text(message)
            except Exception:
                # Remove broken connections
                self.disconnect(connection)


manager = ConnectionManager()
async def startup_event():
    """Initialize runtime components on startup.

    Each component is constructed and start()-ed; start failures are logged
    but do not abort the remaining components. Components whose modules
    failed to import (class placeholder is None) are now skipped instead of
    raising TypeError from calling None().

    NOTE(review): no @app.on_event("startup") decorator is visible in this
    chunk — confirm this coroutine is registered with the FastAPI app
    elsewhere.
    """
    global camera, mic, prompt_engine, music_generator, clip_scheduler
    logger.info("Starting Synesthesia runtime components...")
    try:
        # Initialize camera
        if CameraCapture is not None:
            camera = CameraCapture()
            if camera.start():
                logger.info("Camera capture started")
            else:
                logger.error("Failed to start camera capture")
        # Initialize microphone
        if MicCapture is not None:
            mic = MicCapture()
            if mic.start():
                logger.info("Microphone capture started")
            else:
                logger.error("Failed to start microphone capture")
        # Initialize prompt engine
        if GemmaPromptEngine is not None:
            prompt_engine = GemmaPromptEngine()
            if prompt_engine.start():
                logger.info("Gemma prompt engine started")
            else:
                logger.error("Failed to start Gemma prompt engine")
        # Initialize music generator
        if MagentaGeneration is not None:
            music_generator = MagentaGeneration()
            if music_generator.start():
                logger.info("Magenta RT music generator started")
            else:
                logger.error("Failed to start Magenta RT music generator")
        # Initialize clip scheduler
        if ClipScheduler is not None:
            clip_scheduler = ClipScheduler()
            if clip_scheduler.start():
                logger.info("Clip scheduler started")
            else:
                logger.error("Failed to start clip scheduler")
        logger.info("All runtime components initialized")
    except Exception as e:
        # Broad catch so a failing component cannot crash app startup.
        logger.error(f"Error initializing runtime components: {e}")
async def shutdown_event():
    """Cleanup runtime components on shutdown.

    NOTE(review): no @app.on_event("shutdown") decorator is visible in this
    chunk — confirm this coroutine is registered with the FastAPI app.
    """
    logger.info("Shutting down Synesthesia runtime components...")
    # Stop components in declaration order, skipping any that were never
    # initialized (still None).
    for component in (camera, mic, prompt_engine, music_generator, clip_scheduler):
        if component:
            component.stop()
    logger.info("All runtime components stopped")
async def root():
    """Return the service banner for the API root (name and version)."""
    banner = {"message": "Synesthesia Runtime API", "version": "0.1.0"}
    return banner
async def get_status():
    """Report whether each runtime component is currently running."""
    components = {
        "camera": camera,
        "mic": mic,
        "prompt_engine": prompt_engine,
        "music_generator": music_generator,
        "clip_scheduler": clip_scheduler,
    }
    # A component that was never initialized (None) reports as not running.
    status = {
        name: component.is_running() if component else False
        for name, component in components.items()
    }
    if clip_scheduler:
        # Expose the scheduler's richer status alongside the booleans.
        status["scheduler_details"] = clip_scheduler.get_status()
    return status
# WebSocket endpoint for real-time metrics and updates
async def websocket_endpoint(websocket: WebSocket):
    """Push a keep-alive ping to the client once per second until it disconnects."""
    await manager.connect(websocket)
    # The payload never changes, so serialize it once outside the loop.
    ping = json.dumps({"type": "ping"})
    try:
        while True:
            await asyncio.sleep(1)
            # In a real implementation, actual metrics would be sent here.
            await manager.send_personal_message(ping, websocket)
    except WebSocketDisconnect:
        manager.disconnect(websocket)
# Model management endpoints
async def load_model(model_name: str, precision: str = "4bit"):
    """Kick off loading of *model_name* at *precision* (placeholder)."""
    # In a real implementation, this would trigger model loading.
    note = f"Model {model_name} with precision {precision} loading initiated"
    return {"message": note}
async def unload_model(model_name: str):
    """Kick off unloading of *model_name* (placeholder)."""
    return {"message": "Model {} unloading initiated".format(model_name)}
async def list_models():
    """Enumerate the models the runtime knows about (placeholder).

    Every model currently advertises the same three precision options.
    """
    precisions = ["4bit", "8bit", "fp16"]
    names = ["gemma-3n-e2b", "gemma-3n-e4b", "magenta-rt-small", "magenta-rt-large"]
    return {
        "models": [{"name": name, "precisions": list(precisions)} for name in names]
    }
# Generation control endpoints
async def start_generation():
    """Start continuous music generation (placeholder).

    A real implementation would kick off the generation loop here.
    """
    return {"message": "Continuous generation started"}
async def stop_generation():
    """Stop continuous music generation (placeholder).

    A real implementation would halt the generation loop here.
    """
    return {"message": "Continuous generation stopped"}
async def inject_prompt(prompt: str):
    """Queue a custom prompt for the next generation cycle.

    Args:
        prompt: Free-form prompt text forwarded to the clip scheduler.

    Returns:
        A confirmation message echoing (a preview of) the prompt.

    Raises:
        HTTPException: 503 when the clip scheduler is not available.
    """
    if not clip_scheduler:
        raise HTTPException(status_code=503, detail="Clip scheduler not available")
    clip_scheduler.submit_prompt(prompt)
    # Fix: only elide the preview when the prompt was actually truncated —
    # the original appended "..." even to prompts shorter than 50 chars.
    preview = prompt if len(prompt) <= 50 else prompt[:50] + "..."
    return {"message": f"Prompt injected: {preview}"}
# Metrics endpoints
async def get_current_metrics():
    """Gather current performance metrics from the live components."""
    metrics = {}
    generator = music_generator
    if generator:
        metrics["avg_generation_time"] = generator.get_average_generation_time()
    scheduler = clip_scheduler
    if scheduler:
        # Merge the scheduler's status fields into the payload.
        metrics.update(scheduler.get_status())
    return metrics
if __name__ == "__main__":
    # Development entry point: reload=True restarts the server on code
    # changes and is not intended for production; binds all interfaces.
    uvicorn.run("api.server:app", host="0.0.0.0", port=8000, reload=True)