Spaces:
Runtime error
Runtime error
import logging | |
from dataclasses import asdict, dataclass | |
from pathlib import Path | |
from omegaconf import OmegaConf | |
from rich.console import Console | |
from rich.panel import Panel | |
from rich.table import Table | |
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

# Shared rich console that the pretty-printing helpers in this module write to.
console = Console()
def _make_stft_cfg(hop_length, win_length=None): | |
if win_length is None: | |
win_length = 4 * hop_length | |
n_fft = 2 ** (win_length - 1).bit_length() | |
return dict(n_fft=n_fft, hop_length=hop_length, win_length=win_length) | |
def _build_rich_table(rows, columns, title=None):
    """Render ``rows`` as a rich ``Table`` wrapped in a compact ``Panel``.

    Args:
        rows: Iterable of row sequences; every cell is converted with ``str``.
        columns: Column names; each header is the name passed through
            ``.capitalize()`` and is left-justified.
        title: Optional table title.

    Returns:
        A ``Panel`` created with ``expand=False`` that contains the table.
    """
    headers = [name.capitalize() for name in columns]
    grid = Table(title=title, header_style=None)
    for header in headers:
        grid.add_column(header, justify="left")
    for cells in rows:
        grid.add_row(*(str(cell) for cell in cells))
    return Panel(grid, expand=False)
def _rich_print_dict(d, title="Config", key="Key", value="Value"):
    """Pretty-print mapping ``d`` to the module console as a two-column table.

    ``key`` and ``value`` are the column headers and ``title`` is the table title.
    """
    panel = _build_rich_table(d.items(), [key, value], title=title)
    console.print(panel)
class HParams: | |
# Dataset | |
fg_dir: Path = Path("data/fg") | |
bg_dir: Path = Path("data/bg") | |
rir_dir: Path = Path("data/rir") | |
load_fg_only: bool = False | |
praat_augment_prob: float = 0 | |
# Audio settings | |
wav_rate: int = 44_100 | |
n_fft: int = 2048 | |
win_size: int = 2048 | |
hop_size: int = 420 # 9.5ms | |
num_mels: int = 128 | |
stft_magnitude_min: float = 1e-4 | |
preemphasis: float = 0.97 | |
mix_alpha_range: tuple[float, float] = (0.2, 0.8) | |
# Training | |
nj: int = 64 | |
training_seconds: float = 1.0 | |
batch_size_per_gpu: int = 16 | |
min_lr: float = 1e-5 | |
max_lr: float = 1e-4 | |
warmup_steps: int = 1000 | |
max_steps: int = 1_000_000 | |
gradient_clipping: float = 1.0 | |
def deepspeed_config(self): | |
return { | |
"train_micro_batch_size_per_gpu": self.batch_size_per_gpu, | |
"optimizer": { | |
"type": "Adam", | |
"params": {"lr": float(self.min_lr)}, | |
}, | |
"scheduler": { | |
"type": "WarmupDecayLR", | |
"params": { | |
"warmup_min_lr": float(self.min_lr), | |
"warmup_max_lr": float(self.max_lr), | |
"warmup_num_steps": self.warmup_steps, | |
"total_num_steps": self.max_steps, | |
"warmup_type": "linear", | |
}, | |
}, | |
"gradient_clipping": self.gradient_clipping, | |
} | |
def stft_cfgs(self): | |
assert self.wav_rate == 44_100, f"wav_rate must be 44_100, got {self.wav_rate}" | |
return [_make_stft_cfg(h) for h in (100, 256, 512)] | |
def from_yaml(cls, path: Path) -> "HParams": | |
logger.info(f"Reading hparams from {path}") | |
# First merge to fix types (e.g., str -> Path) | |
return cls(**dict(OmegaConf.merge(cls(), OmegaConf.load(path)))) | |
def save_if_not_exists(self, run_dir: Path): | |
path = run_dir / "hparams.yaml" | |
if path.exists(): | |
logger.info(f"{path} already exists, not saving") | |
return | |
path.parent.mkdir(parents=True, exist_ok=True) | |
OmegaConf.save(asdict(self), str(path)) | |
def load(cls, run_dir, yaml: Path | None = None): | |
hps = [] | |
if (run_dir / "hparams.yaml").exists(): | |
hps.append(cls.from_yaml(run_dir / "hparams.yaml")) | |
if yaml is not None: | |
hps.append(cls.from_yaml(yaml)) | |
if len(hps) == 0: | |
hps.append(cls()) | |
for hp in hps[1:]: | |
if hp != hps[0]: | |
errors = {} | |
for k, v in asdict(hp).items(): | |
if getattr(hps[0], k) != v: | |
errors[k] = f"{getattr(hps[0], k)} != {v}" | |
raise ValueError(f"Found inconsistent hparams: {errors}, consider deleting {run_dir}") | |
return hps[0] | |
def print(self): | |
_rich_print_dict(asdict(self), title="HParams") | |