import logging
from dataclasses import asdict, dataclass
from pathlib import Path

from omegaconf import OmegaConf
from rich.console import Console
from rich.panel import Panel
from rich.table import Table

logger = logging.getLogger(__name__)
console = Console()


def _make_stft_cfg(hop_length, win_length=None):
    # win_length defaults to 4x the hop length; n_fft is the next power of two >= win_length.
    if win_length is None:
        win_length = 4 * hop_length
    n_fft = 2 ** (win_length - 1).bit_length()
    return dict(n_fft=n_fft, hop_length=hop_length, win_length=win_length)


def _build_rich_table(rows, columns, title=None):
    table = Table(title=title, header_style=None)
    for column in columns:
        table.add_column(column.capitalize(), justify="left")
    for row in rows:
        table.add_row(*map(str, row))
    return Panel(table, expand=False)


def _rich_print_dict(d, title="Config", key="Key", value="Value"):
    console.print(_build_rich_table(d.items(), [key, value], title))


@dataclass(frozen=True)
class HParams:
    # Dataset
    fg_dir: Path = Path("data/fg")
    bg_dir: Path = Path("data/bg")
    rir_dir: Path = Path("data/rir")
    load_fg_only: bool = False
    praat_augment_prob: float = 0

    # Audio settings
    wav_rate: int = 44_100
    n_fft: int = 2048
    win_size: int = 2048
    hop_size: int = 420  # 9.5ms
    num_mels: int = 128
    stft_magnitude_min: float = 1e-4
    preemphasis: float = 0.97
    mix_alpha_range: tuple[float, float] = (0.2, 0.8)

    # Training
    nj: int = 64
    training_seconds: float = 1.0
    batch_size_per_gpu: int = 16
    min_lr: float = 1e-5
    max_lr: float = 1e-4
    warmup_steps: int = 1000
    max_steps: int = 1_000_000
    gradient_clipping: float = 1.0

    @property
    def deepspeed_config(self):
        return {
            "train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
            "optimizer": {
                "type": "Adam",
                "params": {"lr": float(self.min_lr)},
            },
            "scheduler": {
                "type": "WarmupDecayLR",
                "params": {
                    "warmup_min_lr": float(self.min_lr),
                    "warmup_max_lr": float(self.max_lr),
                    "warmup_num_steps": self.warmup_steps,
                    "total_num_steps": self.max_steps,
                    "warmup_type": "linear",
                },
            },
            "gradient_clipping": self.gradient_clipping,
        }

    @property
    def stft_cfgs(self):
        # Multi-resolution STFT configs, one per hop length (in samples at 44.1 kHz).
        assert self.wav_rate == 44_100, f"wav_rate must be 44_100, got {self.wav_rate}"
        return [_make_stft_cfg(h) for h in (100, 256, 512)]

    @classmethod
    def from_yaml(cls, path: Path) -> "HParams":
        logger.info(f"Reading hparams from {path}")
        # First merge to fix types (e.g., str -> Path)
        return cls(**dict(OmegaConf.merge(cls(), OmegaConf.load(path))))

    def save_if_not_exists(self, run_dir: Path):
        path = run_dir / "hparams.yaml"
        if path.exists():
            logger.info(f"{path} already exists, not saving")
            return
        path.parent.mkdir(parents=True, exist_ok=True)
        OmegaConf.save(asdict(self), str(path))

    @classmethod
    def load(cls, run_dir, yaml: Path | None = None):
        # Prefer the hparams already saved in run_dir; any extra YAML must agree with them.
        hps = []

        if (run_dir / "hparams.yaml").exists():
            hps.append(cls.from_yaml(run_dir / "hparams.yaml"))

        if yaml is not None:
            hps.append(cls.from_yaml(yaml))

        if len(hps) == 0:
            hps.append(cls())

        for hp in hps[1:]:
            if hp != hps[0]:
                errors = {}
                for k, v in asdict(hp).items():
                    if getattr(hps[0], k) != v:
                        errors[k] = f"{getattr(hps[0], k)} != {v}"
                raise ValueError(f"Found inconsistent hparams: {errors}, consider deleting {run_dir}")

        return hps[0]

    def print(self):
        _rich_print_dict(asdict(self), title="HParams")
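

if __name__ == "__main__":
    # Minimal usage sketch (illustration only, not part of the original module):
    # resolve hparams for a run directory, persist them once, and print the table.
    # The "runs/example" default path and this CLI handling are assumptions.
    import sys

    run_dir = Path(sys.argv[1]) if len(sys.argv) > 1 else Path("runs/example")
    hp = HParams.load(run_dir)  # falls back to defaults if no hparams.yaml exists yet
    hp.save_if_not_exists(run_dir)  # writes <run_dir>/hparams.yaml unless it already exists
    hp.print()  # rich table of every field and its value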