"""
Configuration file for Enhanced Audio Separator Demo
Centralized settings for easy customization
"""
import os
from dataclasses import dataclass
from typing import Dict, List, Optional
@dataclass
class ModelConfig:
"""Configuration for AI models"""
default_model: str = "model_bs_roformer_ep_317_sdr_12.9755.ckpt"
model_cache_dir: str = "/tmp/audio-separator-models/"
auto_download_models: bool = True
max_concurrent_downloads: int = 3
@dataclass
class ProcessingConfig:
"""Configuration for audio processing"""
default_quality_preset: str = "Standard"
    supported_formats: Optional[List[str]] = None
max_file_size_mb: int = 500
max_processing_time_minutes: int = 30
auto_cleanup_temp_files: bool = True
def __post_init__(self):
if self.supported_formats is None:
self.supported_formats = ["wav", "mp3", "flac", "ogg", "opus", "m4a", "aiff"]
@dataclass
class UIConfig:
"""Configuration for user interface"""
title: str = "Enhanced Audio Separator"
description: str = "Advanced audio source separation powered by AI"
theme: str = "soft" # soft, base, default
show_system_info: bool = True
max_history_entries: int = 50
enable_batch_processing: bool = True
enable_model_comparison: bool = True
max_comparison_models: int = 5
@dataclass
class HardwareConfig:
"""Configuration for hardware acceleration"""
enable_cuda: bool = True
enable_mps: bool = True # Apple Silicon
enable_directml: bool = False # Windows
autocast_enabled: bool = True
soundfile_enabled: bool = True
memory_optimization: bool = True
@dataclass
class PerformanceConfig:
"""Configuration for performance optimization"""
default_batch_size: int = 1
default_segment_size: int = 256
default_overlap: float = 0.25
gpu_memory_fraction: float = 0.8
clear_cache_interval: int = 5 # Clear GPU cache every N operations
@dataclass
class SecurityConfig:
"""Configuration for security and safety"""
max_file_size_bytes: int = 500 * 1024 * 1024 # 500MB
    allowed_file_extensions: Optional[List[str]] = None
temp_file_retention_hours: int = 24
enable_file_validation: bool = True
def __post_init__(self):
if self.allowed_file_extensions is None:
            self.allowed_file_extensions = [".wav", ".mp3", ".flac", ".ogg", ".opus", ".m4a", ".aiff"]
class Config:
"""Main configuration class"""
def __init__(self):
self.model = ModelConfig()
self.processing = ProcessingConfig()
self.ui = UIConfig()
self.hardware = HardwareConfig()
self.performance = PerformanceConfig()
self.security = SecurityConfig()
# Load environment variables if they exist
self._load_from_env()
def _load_from_env(self):
"""Load configuration from environment variables"""
# Model configuration
if os.getenv('AUDIO_SEPARATOR_MODEL_DIR'):
self.model.model_cache_dir = os.getenv('AUDIO_SEPARATOR_MODEL_DIR')
# Processing configuration
if os.getenv('MAX_FILE_SIZE_MB'):
try:
self.processing.max_file_size_mb = int(os.getenv('MAX_FILE_SIZE_MB'))
except ValueError:
pass
if os.getenv('AUTO_DOWNLOAD_MODELS'):
self.model.auto_download_models = os.getenv('AUTO_DOWNLOAD_MODELS').lower() == 'true'
# Performance configuration
if os.getenv('DEFAULT_BATCH_SIZE'):
try:
self.performance.default_batch_size = int(os.getenv('DEFAULT_BATCH_SIZE'))
except ValueError:
pass
# UI configuration
if os.getenv('ENABLE_BATCH_PROCESSING'):
self.ui.enable_batch_processing = os.getenv('ENABLE_BATCH_PROCESSING').lower() == 'true'
if os.getenv('ENABLE_MODEL_COMPARISON'):
self.ui.enable_model_comparison = os.getenv('ENABLE_MODEL_COMPARISON').lower() == 'true'
def get_model_presets(self) -> Dict[str, Dict]:
"""Get predefined model processing presets"""
return {
"Fast": {
"mdx_params": {
"batch_size": 4,
"segment_size": 128,
"overlap": 0.1,
"enable_denoise": False
},
"vr_params": {
"batch_size": 8,
"aggression": 3,
"enable_tta": False,
"enable_post_process": False
},
"demucs_params": {
"shifts": 1,
"overlap": 0.1,
"segments_enabled": True
},
"mdxc_params": {
"batch_size": 4,
"overlap": 4,
"pitch_shift": 0
}
},
"Standard": {
"mdx_params": {
"batch_size": 1,
"segment_size": 256,
"overlap": 0.25,
"enable_denoise": False
},
"vr_params": {
"batch_size": 1,
"aggression": 5,
"enable_tta": False,
"enable_post_process": False
},
"demucs_params": {
"shifts": 2,
"overlap": 0.25,
"segments_enabled": True
},
"mdxc_params": {
"batch_size": 1,
"overlap": 8,
"pitch_shift": 0
}
},
"High Quality": {
"mdx_params": {
"batch_size": 1,
"segment_size": 512,
"overlap": 0.5,
"enable_denoise": True
},
"vr_params": {
"batch_size": 1,
"aggression": 8,
"enable_tta": True,
"enable_post_process": True,
"post_process_threshold": 0.2,
"high_end_process": True
},
"demucs_params": {
"shifts": 4,
"overlap": 0.5,
"segments_enabled": False
},
"mdxc_params": {
"batch_size": 1,
"overlap": 16,
"pitch_shift": 0
}
}
}
    def validate_file(self, file_path: str) -> bool:
        """Validate an uploaded file's size and extension"""
        if not self.security.enable_file_validation:
            return True
        # Reject paths that do not exist before querying their size
        if not os.path.isfile(file_path):
            return False
        # Check file size
        if os.path.getsize(file_path) > self.security.max_file_size_bytes:
            return False
        # Check file extension
        file_ext = os.path.splitext(file_path)[1].lower()
        return file_ext in self.security.allowed_file_extensions
# Global configuration instance
config = Config()
# Helper functions for easy access
def get_default_model() -> str:
"""Get the default model name"""
return config.model.default_model
def get_model_preset(preset_name: str) -> Dict:
"""Get processing parameters for a preset"""
presets = config.get_model_presets()
return presets.get(preset_name, presets["Standard"])
def is_gpu_available() -> bool:
"""Check if GPU acceleration is available"""
import torch
return torch.cuda.is_available() or (hasattr(torch.backends, "mps") and torch.backends.mps.is_available())
def get_optimal_settings() -> Dict:
"""Get optimal settings based on available hardware"""
if is_gpu_available():
return {
"use_autocast": True,
"use_soundfile": True,
"batch_size": 1 if torch.cuda.is_available() else 2
}
else:
return {
"use_autocast": False,
"use_soundfile": True,
"batch_size": 1
}
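
# Minimal usage sketch (illustrative only; the demo app normally imports
# `config` and the helpers above rather than running this module directly):
if __name__ == "__main__":
    print(f"Default model: {get_default_model()}")
    print(f"GPU available: {is_gpu_available()}")
    print(f"Optimal settings: {get_optimal_settings()}")
    # Unknown preset names fall back to the "Standard" preset
    preset = get_model_preset("High Quality")
    print(f"High Quality MDX params: {preset['mdx_params']}")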