Spaces:
Sleeping
Sleeping
import logging | |
from functools import cache | |
from pathlib import Path | |
from typing import Union | |
import torch | |
from ..inference import inference | |
from .download import download | |
from .hparams import HParams | |
from .enhancer import Enhancer | |
logger = logging.getLogger(__name__) | |
@cache
def load_enhancer(run_dir: Union[str, Path, None], device):
    """Build an ``Enhancer`` from a run directory, in eval mode on ``device``.

    The result is memoised per ``(run_dir, device)`` via ``functools.cache``.
    Repeated ``denoise``/``enhance`` calls therefore reuse one model instead of
    re-downloading, re-building and re-reading the checkpoint from disk on every
    call.

    Because the instance is shared, any in-place mutation by one caller is seen
    by later callers. For example, ``enhance`` calls ``configurate_`` on it.
    Nothing here guards concurrent use from several threads. Arguments must be
    hashable. Equal-looking device values of different types (a ``str`` vs a
    ``torch.device``) may be cached as separate copies.

    Args:
        run_dir: Location handed to ``download``. ``None`` is forwarded as-is
            (presumably selects a default checkpoint -- TODO confirm in
            ``download``).
        device: Target device passed to ``Enhancer.to``.

    Returns:
        The loaded ``Enhancer``, already switched to eval mode.
    """
    # The ``/`` operator below relies on ``download`` returning a ``Path``.
    run_dir = download(run_dir)
    hp = HParams.load(run_dir)
    enhancer = Enhancer(hp)
    # NOTE(review): this path and the "module" key look like a DeepSpeed
    # checkpoint layout -- confirm against the training code.
    path = run_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt"
    # Load onto CPU first, then move with ``.to(device)`` below. This keeps the
    # load independent of the device the checkpoint was saved from.
    # SECURITY: ``torch.load`` unpickles the file; only use trusted checkpoints.
    state_dict = torch.load(path, map_location="cpu")["module"]
    enhancer.load_state_dict(state_dict)
    enhancer.eval()
    enhancer.to(device)
    return enhancer
def denoise(dwav, sr, device, run_dir=None):
    """Denoise a waveform using only the enhancer's denoiser sub-module.

    Args:
        dwav: Input waveform, forwarded unchanged to ``inference``.
        sr: Sample rate of ``dwav``, forwarded to ``inference``.
        device: Device the model is placed on and inference runs on.
        run_dir: Checkpoint location forwarded to ``load_enhancer``.

    Returns:
        Whatever ``inference`` returns when run with ``enhancer.denoiser``.
    """
    denoiser_module = load_enhancer(run_dir, device).denoiser
    return inference(dwav=dwav, sr=sr, model=denoiser_module, device=device)
class _EnhanceArgumentError(ValueError, AssertionError):
    """Raised by ``enhance`` when an argument is out of range.

    Inherits from ``AssertionError`` as well as ``ValueError``. Callers that
    caught the ``AssertionError`` from the earlier ``assert``-based checks
    therefore keep working.
    """


def enhance(
    dwav, sr, device, nfe=32, solver="midpoint", lambd=0.5, tau=0.5, run_dir=None
):
    """Run the full enhancer (not just the denoiser) on a waveform.

    Args:
        dwav: Input waveform, forwarded unchanged to ``inference``.
        sr: Sample rate of ``dwav``, forwarded to ``inference``.
        device: Device the model is placed on and inference runs on.
        nfe: Must be in (0, 128]. Forwarded to ``configurate_`` (presumably
            the number of function evaluations of the solver -- TODO confirm).
        solver: One of ``"midpoint"``, ``"rk4"``, ``"euler"``. Forwarded to
            ``configurate_``.
        lambd: Must be in [0, 1]. Forwarded to ``configurate_``.
        tau: Must be in [0, 1]. Forwarded to ``configurate_``.
        run_dir: Checkpoint location forwarded to ``load_enhancer``.

    Returns:
        Whatever ``inference`` returns for the configured enhancer.

    Raises:
        ValueError: If any of ``nfe``, ``solver``, ``lambd`` or ``tau`` is out
            of range. The exception is also an ``AssertionError``.
    """
    # Explicit checks rather than ``assert``: asserts are removed under
    # ``python -O``, which would let invalid settings reach the model silently.
    if not (0 < nfe <= 128):
        raise _EnhanceArgumentError(f"nfe must be in (0, 128], got {nfe}")
    if solver not in ("midpoint", "rk4", "euler"):
        raise _EnhanceArgumentError(
            f"solver must be in ('midpoint', 'rk4', 'euler'), got {solver}"
        )
    if not (0 <= lambd <= 1):
        raise _EnhanceArgumentError(f"lambd must be in [0, 1], got {lambd}")
    if not (0 <= tau <= 1):
        raise _EnhanceArgumentError(f"tau must be in [0, 1], got {tau}")

    enhancer = load_enhancer(run_dir, device)
    # NOTE(review): the trailing underscore suggests ``configurate_`` mutates
    # the model in place. It is therefore re-applied on every call rather than
    # assumed from a previous one.
    enhancer.configurate_(nfe=nfe, solver=solver, lambd=lambd, tau=tau)
    return inference(model=enhancer, dwav=dwav, sr=sr, device=device)