Spaces:
Runtime error
Runtime error
| """Configuration management utilities.""" | |
| from dataclasses import dataclass, field, asdict | |
| from typing import Optional, Dict, Any | |
| import yaml | |
| from pathlib import Path | |
@dataclass
class ModelConfig:
    """Model configuration.

    Declared as a dataclass so that instances can be built from keyword
    arguments and converted with ``dataclasses.asdict``.
    """

    # Model identifier; presumably a Hugging Face hub name -- confirm with the loader.
    name: str = "facebook/wav2vec2-base"
    # Device string the model should be placed on.
    device: str = "cuda"
    # Optional checkpoint location; None means no checkpoint is configured.
    checkpoint: Optional[str] = None
@dataclass
class RLConfig:
    """Reinforcement learning configuration.

    Declared as a dataclass so that instances can be built from keyword
    arguments and converted with ``dataclasses.asdict``.
    """

    # Name of the RL algorithm to run.
    algorithm: str = "ppo"
    # Optimizer step size.
    learning_rate: float = 3.0e-4
    batch_size: int = 32
    # Number of training episodes.
    num_episodes: int = 1000
    # Length of a single episode; presumably counted in steps -- confirm with the trainer.
    episode_length: int = 100
    # Presumably the reward discount factor -- confirm with the trainer.
    gamma: float = 0.99
    clip_epsilon: float = 0.2  # PPO specific
    # Presumably the gradient-norm clipping threshold -- confirm with the trainer.
    max_grad_norm: float = 1.0
@dataclass
class DataConfig:
    """Data configuration.

    Declared as a dataclass so that instances can be built from keyword
    arguments and converted with ``dataclasses.asdict``. The split values
    are stored as given; this class does not check that they sum to 1.
    """

    # Location of the dataset on disk.
    dataset_path: str = "data/processed"
    # Fractions of the data assigned to each split (defaults sum to 1.0).
    train_split: float = 0.7
    val_split: float = 0.15
    test_split: float = 0.15
    # Audio sample rate; presumably in Hz -- confirm with the data loader.
    sample_rate: int = 16000
@dataclass
class CurriculumConfig:
    """Curriculum learning configuration.

    Declared as a dataclass so that instances can be built from keyword
    arguments and converted with ``dataclasses.asdict``.
    """

    # Whether curriculum learning is switched on.
    enabled: bool = True
    # Number of curriculum levels.
    levels: int = 5
    # Presumably the score needed to move to the next level -- confirm with the trainer.
    advancement_threshold: float = 0.8
@dataclass
class OptimizationConfig:
    """Optimization configuration.

    Declared as a dataclass so that instances can be built from keyword
    arguments and converted with ``dataclasses.asdict``.
    """

    # Toggle for mixed-precision training.
    mixed_precision: bool = True
    # Toggle for gradient checkpointing.
    gradient_checkpointing: bool = False
@dataclass
class CheckpointConfig:
    """Checkpointing configuration.

    Declared as a dataclass so that instances can be built from keyword
    arguments and converted with ``dataclasses.asdict``.
    """

    interval: int = 50  # episodes
    # Directory that checkpoints are written to.
    save_dir: str = "checkpoints"
    # Presumably how many recent checkpoints to retain -- confirm with the checkpoint writer.
    keep_last_n: int = 5
@dataclass
class MonitoringConfig:
    """Monitoring configuration.

    Declared as a dataclass so that instances can be built from keyword
    arguments and converted with ``dataclasses.asdict``.
    """

    # How often to log; the unit is not fixed here (presumably episodes) -- confirm.
    log_interval: int = 10
    # How often to produce visualizations; same unit caveat as log_interval.
    visualization_interval: int = 50
    # Directory for TensorBoard output.
    tensorboard_dir: str = "runs"
@dataclass
class ReproducibilityConfig:
    """Reproducibility configuration.

    Declared as a dataclass so that instances can be built from keyword
    arguments and converted with ``dataclasses.asdict``.
    """

    # Seed value for random number generation.
    random_seed: int = 42
@dataclass
class TrainingConfig:
    """Complete training configuration.

    Groups one configuration object per concern. Each section defaults to a
    freshly constructed instance of its configuration class, so separate
    ``TrainingConfig`` objects never share section instances.
    """

    model: ModelConfig = field(default_factory=ModelConfig)
    rl: RLConfig = field(default_factory=RLConfig)
    data: DataConfig = field(default_factory=DataConfig)
    curriculum: CurriculumConfig = field(default_factory=CurriculumConfig)
    optimization: OptimizationConfig = field(default_factory=OptimizationConfig)
    checkpointing: CheckpointConfig = field(default_factory=CheckpointConfig)
    monitoring: MonitoringConfig = field(default_factory=MonitoringConfig)
    reproducibility: ReproducibilityConfig = field(default_factory=ReproducibilityConfig)

    @classmethod
    def from_yaml(cls, path: str) -> "TrainingConfig":
        """Load configuration from YAML file.

        Args:
            path: Location of the YAML file to read.

        Returns:
            A new configuration. Sections that are missing from the file, or
            present without a value, are built from their class defaults.

        Raises:
            OSError: If the file cannot be opened.
            TypeError: If a section holds a key that its configuration class
                does not accept.
        """
        with open(path, 'r') as f:
            # safe_load yields None for an empty document; treat it as "no overrides".
            config_dict = yaml.safe_load(f) or {}

        def section(key: str) -> Dict[str, Any]:
            # A key written with no value loads as None, which cannot be **-unpacked.
            return config_dict.get(key) or {}

        return cls(
            model=ModelConfig(**section('model')),
            rl=RLConfig(**section('rl')),
            data=DataConfig(**section('data')),
            curriculum=CurriculumConfig(**section('curriculum')),
            optimization=OptimizationConfig(**section('optimization')),
            checkpointing=CheckpointConfig(**section('checkpointing')),
            monitoring=MonitoringConfig(**section('monitoring')),
            reproducibility=ReproducibilityConfig(**section('reproducibility')),
        )

    def to_yaml(self, path: str) -> None:
        """Save configuration to YAML file.

        The parent directory of ``path`` is created if it does not exist.

        Args:
            path: Destination file; an existing file is overwritten.
        """
        # Build the dictionary first so a conversion error leaves the disk untouched.
        config_dict = self.to_dict()
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        with open(path, 'w') as f:
            yaml.dump(config_dict, f, default_flow_style=False)

    def to_dict(self) -> Dict[str, Any]:
        """Convert configuration to dictionary.

        Returns:
            A mapping from section name to that section's fields, each
            produced by ``dataclasses.asdict``.
        """
        return {
            'model': asdict(self.model),
            'rl': asdict(self.rl),
            'data': asdict(self.data),
            'curriculum': asdict(self.curriculum),
            'optimization': asdict(self.optimization),
            'checkpointing': asdict(self.checkpointing),
            'monitoring': asdict(self.monitoring),
            'reproducibility': asdict(self.reproducibility),
        }