File size: 790 Bytes
32b2aaa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
import logging
from functools import cache
import torch
from ..denoiser.denoiser import Denoiser
from ..inference import inference
from .hparams import HParams
logger = logging.getLogger(__name__)
@cache
def load_denoiser(run_dir, device):
    """Build a ``Denoiser`` and prepare it for inference on ``device``.

    Results are memoized with ``functools.cache``: calling again with the
    same ``(run_dir, device)`` pair returns the same model object, so both
    arguments must be hashable.

    Args:
        run_dir: Run directory passed to ``HParams.load``; it must support
            the ``/`` operator (e.g. a ``pathlib.Path``) because the
            checkpoint is read from
            ``run_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt"``.
            When ``None``, the model is built from default ``HParams()``
            and no weights are loaded.
        device: Target passed to ``denoiser.to(...)``.

    Returns:
        The ``Denoiser`` instance after ``.eval()`` and ``.to(device)`` have
        been called on it.
    """
    if run_dir is None:
        # No run directory: default hyperparameters, no checkpoint to restore.
        denoiser = Denoiser(HParams())
    else:
        denoiser = Denoiser(HParams.load(run_dir))
        path = run_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt"
        # NOTE(security): torch.load unpickles the file, so only checkpoints
        # from trusted sources should be loaded here.
        # Weights are read onto CPU first and moved to `device` below.
        state_dict = torch.load(path, map_location="cpu")["module"]
        denoiser.load_state_dict(state_dict)

    # Finalize both branches identically; previously the `run_dir is None`
    # branch returned early, ignoring `device` and skipping `.eval()`.
    denoiser.eval()
    denoiser.to(device)
    return denoiser
@torch.inference_mode()
def denoise(dwav, sr, run_dir, device):
    """Run the denoiser from ``run_dir`` over ``dwav`` and return the result.

    The model is obtained through ``load_denoiser(run_dir, device)`` and the
    call is forwarded to ``inference``; the whole function executes under
    ``torch.inference_mode()``.

    Args:
        dwav: Input forwarded to ``inference`` as ``dwav``
            (presumably the waveform to denoise -- confirm against ``inference``).
        sr: Forwarded to ``inference`` as ``sr``
            (presumably the sample rate of ``dwav`` -- confirm).
        run_dir: Run directory handed to ``load_denoiser``.
        device: Device handed to both ``load_denoiser`` and ``inference``.

    Returns:
        Whatever ``inference`` returns for the given inputs.
    """
    return inference(
        model=load_denoiser(run_dir, device),
        dwav=dwav,
        sr=sr,
        device=device,
    )
|