Wendyellé Abubakrh Alban NYANTUDRE
deleted parent dir resemble-enhance
689d78f
import logging
from dataclasses import asdict, dataclass
from pathlib import Path
from omegaconf import OmegaConf
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
# Logger named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)
# Shared rich console; _rich_print_dict prints through it.
console: Console = Console()
def _make_stft_cfg(hop_length, win_length=None):
if win_length is None:
win_length = 4 * hop_length
n_fft = 2 ** (win_length - 1).bit_length()
return dict(n_fft=n_fft, hop_length=hop_length, win_length=win_length)
def _build_rich_table(rows, columns, title=None):
    """Render rows as a rich table wrapped in a non-expanding panel.

    Args:
        rows: Iterable of row sequences; every cell is converted with ``str``.
        columns: Column names; each is capitalized for its header.
        title: Optional table title.

    Returns:
        A ``Panel`` (with ``expand=False``) containing the table.
    """
    grid = Table(title=title, header_style=None)
    headers = [name.capitalize() for name in columns]
    for header in headers:
        grid.add_column(header, justify="left")
    for cells in rows:
        grid.add_row(*(str(cell) for cell in cells))
    return Panel(grid, expand=False)
def _rich_print_dict(d, title="Config", key="Key", value="Value"):
    """Print the items of mapping ``d`` as a two-column table on the console.

    ``key`` and ``value`` name the two columns; ``title`` titles the table.
    """
    header_names = [key, value]
    panel = _build_rich_table(d.items(), header_names, title)
    console.print(panel)
@dataclass(frozen=True)
class HParams:
# Dataset
fg_dir: Path = Path("data/fg")
bg_dir: Path = Path("data/bg")
rir_dir: Path = Path("data/rir")
load_fg_only: bool = False
praat_augment_prob: float = 0
# Audio settings
wav_rate: int = 44_100
n_fft: int = 2048
win_size: int = 2048
hop_size: int = 420 # 9.5ms
num_mels: int = 128
stft_magnitude_min: float = 1e-4
preemphasis: float = 0.97
mix_alpha_range: tuple[float, float] = (0.2, 0.8)
# Training
nj: int = 64
training_seconds: float = 1.0
batch_size_per_gpu: int = 16
min_lr: float = 1e-5
max_lr: float = 1e-4
warmup_steps: int = 1000
max_steps: int = 1_000_000
gradient_clipping: float = 1.0
@property
def deepspeed_config(self):
return {
"train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
"optimizer": {
"type": "Adam",
"params": {"lr": float(self.min_lr)},
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": float(self.min_lr),
"warmup_max_lr": float(self.max_lr),
"warmup_num_steps": self.warmup_steps,
"total_num_steps": self.max_steps,
"warmup_type": "linear",
},
},
"gradient_clipping": self.gradient_clipping,
}
@property
def stft_cfgs(self):
assert self.wav_rate == 44_100, f"wav_rate must be 44_100, got {self.wav_rate}"
return [_make_stft_cfg(h) for h in (100, 256, 512)]
@classmethod
def from_yaml(cls, path: Path) -> "HParams":
logger.info(f"Reading hparams from {path}")
# First merge to fix types (e.g., str -> Path)
return cls(**dict(OmegaConf.merge(cls(), OmegaConf.load(path))))
def save_if_not_exists(self, run_dir: Path):
path = run_dir / "hparams.yaml"
if path.exists():
logger.info(f"{path} already exists, not saving")
return
path.parent.mkdir(parents=True, exist_ok=True)
OmegaConf.save(asdict(self), str(path))
@classmethod
def load(cls, run_dir, yaml: Path | None = None):
hps = []
if (run_dir / "hparams.yaml").exists():
hps.append(cls.from_yaml(run_dir / "hparams.yaml"))
if yaml is not None:
hps.append(cls.from_yaml(yaml))
if len(hps) == 0:
hps.append(cls())
for hp in hps[1:]:
if hp != hps[0]:
errors = {}
for k, v in asdict(hp).items():
if getattr(hps[0], k) != v:
errors[k] = f"{getattr(hps[0], k)} != {v}"
raise ValueError(f"Found inconsistent hparams: {errors}, consider deleting {run_dir}")
return hps[0]
def print(self):
_rich_print_dict(asdict(self), title="HParams")