Spaces:
Build error
Build error
| """ | |
| Configuration file for Enhanced Audio Separator Demo | |
| Centralized settings for easy customization | |
| """ | |
| import os | |
| from dataclasses import dataclass | |
| from typing import Dict, List, Optional | |
@dataclass
class ModelConfig:
    """Configuration for AI models.

    Declared as a dataclass (``dataclass`` is imported by this module for that
    purpose) so instances get a keyword ``__init__``, ``__eq__`` and a readable
    ``__repr__`` instead of sharing bare class attributes.
    """
    # Model identifier returned by get_default_model().
    default_model: str = "model_bs_roformer_ep_317_sdr_12.9755.ckpt"
    # Overridden by the AUDIO_SEPARATOR_MODEL_DIR environment variable.
    model_cache_dir: str = "/tmp/audio-separator-models/"
    # Overridden by AUTO_DOWNLOAD_MODELS (true only for the literal 'true').
    auto_download_models: bool = True
    max_concurrent_downloads: int = 3
@dataclass
class ProcessingConfig:
    """Configuration for audio processing.

    Must be a dataclass: ``__post_init__`` is a dataclass hook and is never
    called on a plain class, which would leave ``supported_formats`` as None.
    """
    # Presumably the name of a preset from Config.get_model_presets() — verify against callers.
    default_quality_preset: str = "Standard"
    # None means "use the built-in list". __post_init__ fills it per instance so
    # no mutable list is shared between instances.
    supported_formats: Optional[List[str]] = None
    # Overridden by the MAX_FILE_SIZE_MB environment variable.
    max_file_size_mb: int = 500
    max_processing_time_minutes: int = 30
    auto_cleanup_temp_files: bool = True

    def __post_init__(self):
        """Fill in the default list of supported audio formats when none was given."""
        if self.supported_formats is None:
            self.supported_formats = ["wav", "mp3", "flac", "ogg", "opus", "m4a", "aiff"]
@dataclass
class UIConfig:
    """Configuration for the user interface (declared as a dataclass like its sibling sections)."""
    title: str = "Enhanced Audio Separator"
    description: str = "Advanced audio source separation powered by AI"
    theme: str = "soft"  # soft, base, default
    show_system_info: bool = True
    max_history_entries: int = 50
    # Overridden by ENABLE_BATCH_PROCESSING (true only for the literal 'true').
    enable_batch_processing: bool = True
    # Overridden by ENABLE_MODEL_COMPARISON (true only for the literal 'true').
    enable_model_comparison: bool = True
    max_comparison_models: int = 5
@dataclass
class HardwareConfig:
    """Configuration for hardware acceleration (declared as a dataclass like its sibling sections)."""
    enable_cuda: bool = True
    enable_mps: bool = True  # Apple Silicon
    enable_directml: bool = False  # Windows
    autocast_enabled: bool = True
    soundfile_enabled: bool = True
    memory_optimization: bool = True
@dataclass
class PerformanceConfig:
    """Configuration for performance optimization (declared as a dataclass like its sibling sections)."""
    # Overridden by the DEFAULT_BATCH_SIZE environment variable.
    default_batch_size: int = 1
    default_segment_size: int = 256
    default_overlap: float = 0.25
    gpu_memory_fraction: float = 0.8
    clear_cache_interval: int = 5  # Clear GPU cache every N operations
@dataclass
class SecurityConfig:
    """Configuration for security and safety.

    Must be a dataclass: ``__post_init__`` is a dataclass hook. On a plain class
    ``allowed_file_extensions`` would remain None, and the membership test in
    Config.validate_file would raise TypeError.
    """
    # Upper bound on upload size, enforced by Config.validate_file.
    max_file_size_bytes: int = 500 * 1024 * 1024  # 500MB
    # Lower-case extensions including the leading dot (validate_file compares
    # against os.path.splitext(...)[1].lower()). None means "use the built-in list".
    allowed_file_extensions: Optional[List[str]] = None
    temp_file_retention_hours: int = 24
    # When False, Config.validate_file accepts every file without checks.
    enable_file_validation: bool = True

    def __post_init__(self):
        """Fill in the default list of allowed extensions (each listed once)."""
        if self.allowed_file_extensions is None:
            self.allowed_file_extensions = [".wav", ".mp3", ".flac", ".ogg", ".opus", ".m4a", ".aiff"]
class Config:
    """Main configuration class.

    Aggregates one object per settings section (model, processing, ui,
    hardware, performance, security). On construction, defaults are then
    selectively overridden from environment variables (see ``_load_from_env``).
    """

    def __init__(self):
        self.model = ModelConfig()
        self.processing = ProcessingConfig()
        self.ui = UIConfig()
        self.hardware = HardwareConfig()
        self.performance = PerformanceConfig()
        self.security = SecurityConfig()
        # Load environment variables if they exist
        self._load_from_env()

    @staticmethod
    def _env_int(name: str) -> Optional[int]:
        """Return env var *name* parsed as an int, or None if unset, empty or invalid.

        An unparsable value emits a RuntimeWarning and is otherwise ignored, so
        a typo in the environment never prevents startup.
        """
        raw = os.getenv(name)
        if not raw:
            return None
        try:
            return int(raw)
        except ValueError:
            import warnings
            warnings.warn(f"Ignoring {name}={raw!r}: not an integer", RuntimeWarning, stacklevel=3)
            return None

    @staticmethod
    def _env_bool(name: str) -> Optional[bool]:
        """Return whether env var *name* equals 'true' (any case); None if unset or empty."""
        raw = os.getenv(name)
        if not raw:
            return None
        return raw.lower() == 'true'

    def _load_from_env(self):
        """Load configuration overrides from environment variables.

        Recognised variables: AUDIO_SEPARATOR_MODEL_DIR, MAX_FILE_SIZE_MB,
        AUTO_DOWNLOAD_MODELS, DEFAULT_BATCH_SIZE, ENABLE_BATCH_PROCESSING,
        ENABLE_MODEL_COMPARISON. Unset or empty variables leave the defaults
        untouched.
        """
        # Model configuration
        model_dir = os.getenv('AUDIO_SEPARATOR_MODEL_DIR')
        if model_dir:
            self.model.model_cache_dir = model_dir

        # Processing configuration
        max_mb = self._env_int('MAX_FILE_SIZE_MB')
        if max_mb is not None:
            self.processing.max_file_size_mb = max_mb
            # validate_file() enforces security.max_file_size_bytes, so keep it in
            # sync; otherwise this variable would have no effect on upload checks.
            self.security.max_file_size_bytes = max_mb * 1024 * 1024

        auto_download = self._env_bool('AUTO_DOWNLOAD_MODELS')
        if auto_download is not None:
            self.model.auto_download_models = auto_download

        # Performance configuration
        batch_size = self._env_int('DEFAULT_BATCH_SIZE')
        if batch_size is not None:
            self.performance.default_batch_size = batch_size

        # UI configuration
        batch_processing = self._env_bool('ENABLE_BATCH_PROCESSING')
        if batch_processing is not None:
            self.ui.enable_batch_processing = batch_processing
        model_comparison = self._env_bool('ENABLE_MODEL_COMPARISON')
        if model_comparison is not None:
            self.ui.enable_model_comparison = model_comparison

    def get_model_presets(self) -> Dict[str, Dict]:
        """Get predefined model processing presets.

        Returns a freshly built dict on every call, keyed by preset name
        ("Fast", "Standard", "High Quality"). Each value maps an architecture
        parameter group (mdx_params, vr_params, demucs_params, mdxc_params) to
        its settings, so callers may mutate the result safely.
        """
        return {
            "Fast": {
                "mdx_params": {
                    "batch_size": 4,
                    "segment_size": 128,
                    "overlap": 0.1,
                    "enable_denoise": False
                },
                "vr_params": {
                    "batch_size": 8,
                    "aggression": 3,
                    "enable_tta": False,
                    "enable_post_process": False
                },
                "demucs_params": {
                    "shifts": 1,
                    "overlap": 0.1,
                    "segments_enabled": True
                },
                "mdxc_params": {
                    "batch_size": 4,
                    "overlap": 4,
                    "pitch_shift": 0
                }
            },
            "Standard": {
                "mdx_params": {
                    "batch_size": 1,
                    "segment_size": 256,
                    "overlap": 0.25,
                    "enable_denoise": False
                },
                "vr_params": {
                    "batch_size": 1,
                    "aggression": 5,
                    "enable_tta": False,
                    "enable_post_process": False
                },
                "demucs_params": {
                    "shifts": 2,
                    "overlap": 0.25,
                    "segments_enabled": True
                },
                "mdxc_params": {
                    "batch_size": 1,
                    "overlap": 8,
                    "pitch_shift": 0
                }
            },
            "High Quality": {
                "mdx_params": {
                    "batch_size": 1,
                    "segment_size": 512,
                    "overlap": 0.5,
                    "enable_denoise": True
                },
                "vr_params": {
                    "batch_size": 1,
                    "aggression": 8,
                    "enable_tta": True,
                    "enable_post_process": True,
                    "post_process_threshold": 0.2,
                    "high_end_process": True
                },
                "demucs_params": {
                    "shifts": 4,
                    "overlap": 0.5,
                    "segments_enabled": False
                },
                "mdxc_params": {
                    "batch_size": 1,
                    "overlap": 16,
                    "pitch_shift": 0
                }
            }
        }

    def validate_file(self, file_path: str) -> bool:
        """Return True if *file_path* passes the configured upload checks.

        When ``security.enable_file_validation`` is off, every path is accepted.
        Otherwise the file must exist, be no larger than
        ``security.max_file_size_bytes``, and have a lower-cased extension listed
        in ``security.allowed_file_extensions``. A missing or unreadable file is
        reported as invalid instead of raising.
        """
        if not self.security.enable_file_validation:
            return True
        # Check file size
        try:
            size = os.path.getsize(file_path)
        except OSError:
            return False
        if size > self.security.max_file_size_bytes:
            return False
        # Check file extension
        file_ext = os.path.splitext(file_path)[1].lower()
        return file_ext in self.security.allowed_file_extensions
# Global configuration instance, built once when this module is imported.
# Config.__init__ also applies the environment-variable overrides, so those
# variables must be set before the module is first imported to take effect.
config = Config()
# Helper functions for easy access
def get_default_model() -> str:
    """Return the default model name stored on the module-level config."""
    model_section = config.model
    return model_section.default_model
def get_model_preset(preset_name: str) -> Dict:
    """Look up the processing parameters for *preset_name*.

    Unknown preset names fall back to the "Standard" preset.
    """
    all_presets = config.get_model_presets()
    if preset_name in all_presets:
        return all_presets[preset_name]
    return all_presets["Standard"]
def is_gpu_available() -> bool:
    """Return True if PyTorch reports a usable CUDA device or an MPS backend.

    Returns False when PyTorch is not installed, so callers fall back to CPU
    settings instead of crashing with ImportError.
    """
    try:
        import torch  # imported lazily: heavy dependency, only needed here
    except ImportError:
        return False
    # The hasattr guard protects against PyTorch builds whose torch.backends
    # lacks an `mps` attribute.
    return torch.cuda.is_available() or (hasattr(torch.backends, "mps") and torch.backends.mps.is_available())
def get_optimal_settings() -> Dict:
    """Get default runtime flags tuned to the available accelerator.

    Returns a dict with ``use_autocast``, ``use_soundfile`` and ``batch_size``.
    With a GPU, autocast is on and the batch size is 1 on CUDA or 2 otherwise
    (i.e. when only MPS is available). On CPU, autocast is off and the batch
    size is 1.
    """
    if not is_gpu_available():
        return {
            "use_autocast": False,
            "use_soundfile": True,
            "batch_size": 1
        }
    # torch is imported only inside is_gpu_available(); it must be imported in
    # this scope too, otherwise the reference below raises NameError. Reaching
    # this line means the import already succeeded once.
    import torch
    return {
        "use_autocast": True,
        "use_soundfile": True,
        "batch_size": 1 if torch.cuda.is_available() else 2
    }