import logging
from functools import cache
from pathlib import Path
from typing import Union

import torch

from ..inference import inference
from .download import download
from .hparams import HParams
from .enhancer import Enhancer

logger = logging.getLogger(__name__)


@cache
def load_enhancer(run_dir: Union[str, Path, None], device):
    # Resolve (and, if necessary, download) the pretrained run directory,
    # then build the enhancer from its saved hyperparameters.
    run_dir = download(run_dir)
    hp = HParams.load(run_dir)
    enhancer = Enhancer(hp)

    # Load the generator weights from the DeepSpeed checkpoint.
    path = run_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt"
    state_dict = torch.load(path, map_location="cpu")["module"]
    enhancer.load_state_dict(state_dict)

    enhancer.eval()
    enhancer.to(device)
    return enhancer


@torch.inference_mode()
def denoise(dwav, sr, device, run_dir=None):
    # Run only the denoiser sub-module of the enhancer.
    enhancer = load_enhancer(run_dir, device)
    return inference(model=enhancer.denoiser, dwav=dwav, sr=sr, device=device)


@torch.inference_mode()
def enhance(
    dwav, sr, device, nfe=32, solver="midpoint", lambd=0.5, tau=0.5, run_dir=None
):
    assert 0 < nfe <= 128, f"nfe must be in (0, 128], got {nfe}"
    assert solver in (
        "midpoint",
        "rk4",
        "euler",
    ), f"solver must be in ('midpoint', 'rk4', 'euler'), got {solver}"
    assert 0 <= lambd <= 1, f"lambd must be in [0, 1], got {lambd}"
    assert 0 <= tau <= 1, f"tau must be in [0, 1], got {tau}"

    # Configure the solver settings in place, then run the full enhancer.
    enhancer = load_enhancer(run_dir, device)
    enhancer.configurate_(nfe=nfe, solver=solver, lambd=lambd, tau=tau)
    return inference(model=enhancer, dwav=dwav, sr=sr, device=device)
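

# A minimal usage sketch, not part of the original module. It assumes that
# `inference` returns a `(waveform, sample_rate)` tuple and that torchaudio
# is installed; the file names and the device choice are illustrative only.
if __name__ == "__main__":
    import torchaudio

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the input recording and mix it down to the mono 1-D tensor
    # that `denoise` / `enhance` expect for `dwav`.
    dwav, sr = torchaudio.load("noisy.wav")
    dwav = dwav.mean(dim=0)

    # Denoise only.
    clean_wav, clean_sr = denoise(dwav, sr, device)
    torchaudio.save("denoised.wav", clean_wav.unsqueeze(0).cpu(), clean_sr)

    # Full enhancement with explicit solver settings.
    hwav, new_sr = enhance(
        dwav, sr, device, nfe=64, solver="midpoint", lambd=0.9, tau=0.5
    )
    torchaudio.save("enhanced.wav", hwav.unsqueeze(0).cpu(), new_sr)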