""" Configuration file for Enhanced Audio Separator Demo Centralized settings for easy customization """ import os from dataclasses import dataclass from typing import Dict, List, Optional @dataclass class ModelConfig: """Configuration for AI models""" default_model: str = "model_bs_roformer_ep_317_sdr_12.9755.ckpt" model_cache_dir: str = "/tmp/audio-separator-models/" auto_download_models: bool = True max_concurrent_downloads: int = 3 @dataclass class ProcessingConfig: """Configuration for audio processing""" default_quality_preset: str = "Standard" supported_formats: List[str] = None max_file_size_mb: int = 500 max_processing_time_minutes: int = 30 auto_cleanup_temp_files: bool = True def __post_init__(self): if self.supported_formats is None: self.supported_formats = ["wav", "mp3", "flac", "ogg", "opus", "m4a", "aiff"] @dataclass class UIConfig: """Configuration for user interface""" title: str = "Enhanced Audio Separator" description: str = "Advanced audio source separation powered by AI" theme: str = "soft" # soft, base, default show_system_info: bool = True max_history_entries: int = 50 enable_batch_processing: bool = True enable_model_comparison: bool = True max_comparison_models: int = 5 @dataclass class HardwareConfig: """Configuration for hardware acceleration""" enable_cuda: bool = True enable_mps: bool = True # Apple Silicon enable_directml: bool = False # Windows autocast_enabled: bool = True soundfile_enabled: bool = True memory_optimization: bool = True @dataclass class PerformanceConfig: """Configuration for performance optimization""" default_batch_size: int = 1 default_segment_size: int = 256 default_overlap: float = 0.25 gpu_memory_fraction: float = 0.8 clear_cache_interval: int = 5 # Clear GPU cache every N operations @dataclass class SecurityConfig: """Configuration for security and safety""" max_file_size_bytes: int = 500 * 1024 * 1024 # 500MB allowed_file_extensions: List[str] = None temp_file_retention_hours: int = 24 enable_file_validation: bool = True def __post_init__(self): if self.allowed_file_extensions is None: self.allowed_file_extensions = [".wav", ".mp3", ".flac", ".ogg", ".opus", ".m4a", ".aiff", ".m4a"] class Config: """Main configuration class""" def __init__(self): self.model = ModelConfig() self.processing = ProcessingConfig() self.ui = UIConfig() self.hardware = HardwareConfig() self.performance = PerformanceConfig() self.security = SecurityConfig() # Load environment variables if they exist self._load_from_env() def _load_from_env(self): """Load configuration from environment variables""" # Model configuration if os.getenv('AUDIO_SEPARATOR_MODEL_DIR'): self.model.model_cache_dir = os.getenv('AUDIO_SEPARATOR_MODEL_DIR') # Processing configuration if os.getenv('MAX_FILE_SIZE_MB'): try: self.processing.max_file_size_mb = int(os.getenv('MAX_FILE_SIZE_MB')) except ValueError: pass if os.getenv('AUTO_DOWNLOAD_MODELS'): self.model.auto_download_models = os.getenv('AUTO_DOWNLOAD_MODELS').lower() == 'true' # Performance configuration if os.getenv('DEFAULT_BATCH_SIZE'): try: self.performance.default_batch_size = int(os.getenv('DEFAULT_BATCH_SIZE')) except ValueError: pass # UI configuration if os.getenv('ENABLE_BATCH_PROCESSING'): self.ui.enable_batch_processing = os.getenv('ENABLE_BATCH_PROCESSING').lower() == 'true' if os.getenv('ENABLE_MODEL_COMPARISON'): self.ui.enable_model_comparison = os.getenv('ENABLE_MODEL_COMPARISON').lower() == 'true' def get_model_presets(self) -> Dict[str, Dict]: """Get predefined model processing presets""" return { "Fast": { "mdx_params": { "batch_size": 4, "segment_size": 128, "overlap": 0.1, "enable_denoise": False }, "vr_params": { "batch_size": 8, "aggression": 3, "enable_tta": False, "enable_post_process": False }, "demucs_params": { "shifts": 1, "overlap": 0.1, "segments_enabled": True }, "mdxc_params": { "batch_size": 4, "overlap": 4, "pitch_shift": 0 } }, "Standard": { "mdx_params": { "batch_size": 1, "segment_size": 256, "overlap": 0.25, "enable_denoise": False }, "vr_params": { "batch_size": 1, "aggression": 5, "enable_tta": False, "enable_post_process": False }, "demucs_params": { "shifts": 2, "overlap": 0.25, "segments_enabled": True }, "mdxc_params": { "batch_size": 1, "overlap": 8, "pitch_shift": 0 } }, "High Quality": { "mdx_params": { "batch_size": 1, "segment_size": 512, "overlap": 0.5, "enable_denoise": True }, "vr_params": { "batch_size": 1, "aggression": 8, "enable_tta": True, "enable_post_process": True, "post_process_threshold": 0.2, "high_end_process": True }, "demucs_params": { "shifts": 4, "overlap": 0.5, "segments_enabled": False }, "mdxc_params": { "batch_size": 1, "overlap": 16, "pitch_shift": 0 } } } def validate_file(self, file_path: str) -> bool: """Validate uploaded file""" if not self.security.enable_file_validation: return True # Check file size if os.path.getsize(file_path) > self.security.max_file_size_bytes: return False # Check file extension file_ext = os.path.splitext(file_path)[1].lower() return file_ext in self.security.allowed_file_extensions # Global configuration instance config = Config() # Helper functions for easy access def get_default_model() -> str: """Get the default model name""" return config.model.default_model def get_model_preset(preset_name: str) -> Dict: """Get processing parameters for a preset""" presets = config.get_model_presets() return presets.get(preset_name, presets["Standard"]) def is_gpu_available() -> bool: """Check if GPU acceleration is available""" import torch return torch.cuda.is_available() or (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) def get_optimal_settings() -> Dict: """Get optimal settings based on available hardware""" if is_gpu_available(): return { "use_autocast": True, "use_soundfile": True, "batch_size": 1 if torch.cuda.is_available() else 2 } else: return { "use_autocast": False, "use_soundfile": True, "batch_size": 1 }