import logging
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Union

from omegaconf import OmegaConf
from rich.console import Console
from rich.panel import Panel
from rich.table import Table

logger = logging.getLogger(__name__)
console = Console()


def _make_stft_cfg(hop_length, win_length=None):
    if win_length is None:
        win_length = 4 * hop_length
    n_fft = 2 ** (win_length - 1).bit_length()
    return dict(n_fft=n_fft, hop_length=hop_length, win_length=win_length)


def _build_rich_table(rows, columns, title=None):
    table = Table(title=title, header_style=None)
    for column in columns:
        table.add_column(column.capitalize(), justify="left")
    for row in rows:
        table.add_row(*map(str, row))
    return Panel(table, expand=False)


def _rich_print_dict(d, title="Config", key="Key", value="Value"):
    console.print(_build_rich_table(d.items(), [key, value], title))


@dataclass(frozen=True)
class HParams:
    # Dataset
    fg_dir: Path = Path("data/fg")
    bg_dir: Path = Path("data/bg")
    rir_dir: Path = Path("data/rir")
    load_fg_only: bool = False
    praat_augment_prob: float = 0

    # Audio settings
    wav_rate: int = 44_100
    n_fft: int = 2048
    win_size: int = 2048
    hop_size: int = 420  # 9.5ms
    num_mels: int = 128
    stft_magnitude_min: float = 1e-4
    preemphasis: float = 0.97
    mix_alpha_range: tuple[float, float] = (0.2, 0.8)

    # Training
    nj: int = 64
    training_seconds: float = 1.0
    batch_size_per_gpu: int = 16
    min_lr: float = 1e-5
    max_lr: float = 1e-4
    warmup_steps: int = 1000
    max_steps: int = 1_000_000
    gradient_clipping: float = 1.0

    @property
    def deepspeed_config(self):
        return {
            "train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
            "optimizer": {
                "type": "Adam",
                "params": {"lr": float(self.min_lr)},
            },
            "scheduler": {
                "type": "WarmupDecayLR",
                "params": {
                    "warmup_min_lr": float(self.min_lr),
                    "warmup_max_lr": float(self.max_lr),
                    "warmup_num_steps": self.warmup_steps,
                    "total_num_steps": self.max_steps,
                    "warmup_type": "linear",
                },
            },
            "gradient_clipping": self.gradient_clipping,
        }

    @property
    def stft_cfgs(self):
        assert self.wav_rate == 44_100, f"wav_rate must be 44_100, got {self.wav_rate}"
        return [_make_stft_cfg(h) for h in (100, 256, 512)]

    @classmethod
    def from_yaml(cls, path: Path) -> "HParams":
        logger.info(f"Reading hparams from {path}")
        # First merge to fix types (e.g., str -> Path)
        return cls(**dict(OmegaConf.merge(cls(), OmegaConf.load(path))))

    def save_if_not_exists(self, run_dir: Path):
        path = run_dir / "hparams.yaml"
        if path.exists():
            logger.info(f"{path} already exists, not saving")
            return
        path.parent.mkdir(parents=True, exist_ok=True)
        OmegaConf.save(asdict(self), str(path))

    @classmethod
    def load(cls, run_dir, yaml: Union[Path, None] = None):
        hps = []

        if (run_dir / "hparams.yaml").exists():
            hps.append(cls.from_yaml(run_dir / "hparams.yaml"))

        if yaml is not None:
            hps.append(cls.from_yaml(yaml))

        if len(hps) == 0:
            hps.append(cls())

        for hp in hps[1:]:
            if hp != hps[0]:
                errors = {}
                for k, v in asdict(hp).items():
                    if getattr(hps[0], k) != v:
                        errors[k] = f"{getattr(hps[0], k)} != {v}"
                raise ValueError(
                    f"Found inconsistent hparams: {errors}, consider deleting {run_dir}"
                )

        return hps[0]

    def print(self):
        _rich_print_dict(asdict(self), title="HParams")
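

# A minimal usage sketch, not part of the original module: it only exercises the
# public methods defined above (save_if_not_exists, load, print). The "runs/demo"
# directory is a hypothetical example path, not a project convention.
if __name__ == "__main__":
    run_dir = Path("runs/demo")
    run_dir.mkdir(parents=True, exist_ok=True)

    # Build the default hyperparameters, persist them once, then reload from the
    # run directory and display the resulting table.
    hp = HParams()
    hp.save_if_not_exists(run_dir)
    hp = HParams.load(run_dir)
    hp.print()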