Spaces:
Sleeping
Sleeping
import logging | |
from typing import Union | |
import matplotlib.pyplot as plt | |
import pandas as pd | |
import torch | |
from torch import Tensor, nn | |
from torch.distributions import Beta | |
from ..common import Normalizer | |
from ..denoiser.inference import load_denoiser | |
from ..melspec import MelSpectrogram | |
from .hparams import HParams | |
from .lcfm import CFM, IRMAE, LCFM | |
from .univnet import UnivNet | |
# Logger for this module; its name is the module's import path.
logger = logging.getLogger(name=__name__)
def _maybe(fn):
    """Wrap ``fn`` so that a ``None`` first positional argument short-circuits to ``None``.

    The returned wrapper returns ``None`` without calling ``fn`` when its first
    positional argument is ``None``. Otherwise it forwards all positional AND
    keyword arguments to ``fn`` and returns the result. It is used for optional
    inputs, e.g. ``_maybe(_normalize_wav)(y)`` where ``y`` may be ``None``.

    Args:
        fn: any callable (plain function, bound method or callable object).

    Returns:
        The wrapping function.
    """
    import functools

    # updated=(): do not merge ``fn.__dict__`` into the wrapper. ``fn`` may be a
    # callable object (e.g. ``self.normalizer``) whose ``__dict__`` holds its state.
    @functools.wraps(fn, updated=())
    def _fn(*args, **kwargs):
        # Guard on ``args`` so a call without positional arguments reaches ``fn``
        # instead of failing with IndexError.
        if args and args[0] is None:
            return None
        return fn(*args, **kwargs)

    return _fn
def _normalize_wav(x: Tensor, *, eps: float = 1e-7) -> Tensor:
    """Peak-normalize ``x`` along its last dimension.

    Every slice along the last dimension is divided by its maximum absolute
    value plus ``eps``, so its peak magnitude becomes just under 1.

    Args:
        x: waveform tensor, e.g. (b t); each row is normalized independently.
        eps: added to the peak so that an all-zero (silent) slice yields zeros
            instead of NaNs. Defaults to 1e-7.

    Returns:
        Tensor with the same shape as ``x``.
    """
    peak = x.abs().max(dim=-1, keepdim=True).values
    return x / (peak + eps)
class Enhancer(nn.Module):
    """Speech enhancer: (optional) denoiser -> latent CFM (``LCFM``) -> ``UnivNet`` vocoder.

    Works on batches of waveforms shaped (b t); see ``forward``.
    """

    def __init__(self, hp: "HParams"):
        super().__init__()
        self.hp = hp

        n_mels = self.hp.num_mels
        # The vocoder consumes the mel bins plus ``vocoder_extra_dim`` extra channels;
        # the IRMAE decoder is built to emit exactly that many.
        vocoder_input_dim = n_mels + self.hp.vocoder_extra_dim
        latent_dim = self.hp.lcfm_latent_dim

        self.lcfm = LCFM(
            IRMAE(
                input_dim=n_mels,
                output_dim=vocoder_input_dim,
                latent_dim=latent_dim,
            ),
            CFM(
                cond_dim=n_mels,
                output_dim=latent_dim,
                solver_nfe=self.hp.cfm_solver_nfe,
                solver_method=self.hp.cfm_solver_method,
                time_mapping_divisor=self.hp.cfm_time_mapping_divisor,
            ),
            z_scale=self.hp.lcfm_z_scale,
        )
        self.lcfm.set_mode_(self.hp.lcfm_training_mode)

        self.mel_fn = MelSpectrogram(hp)
        self.vocoder = UnivNet(self.hp, vocoder_input_dim)
        # The denoiser is loaded onto CPU here.
        self.denoiser = load_denoiser(self.hp.denoiser_run_dir, "cpu")
        self.normalizer = Normalizer()

        # Denoiser strength used by ``forward`` in eval mode; set via ``configurate_``.
        self._eval_lambd = 0.0

        # 1-element zero buffer; presumably a device/dtype marker for the module — TODO confirm.
        self.dummy: Tensor
        self.register_buffer("dummy", torch.zeros(1))

        if self.hp.enhancer_stage1_run_dir is not None:
            pretrained_path = (
                self.hp.enhancer_stage1_run_dir
                / "ds/G/default/mp_rank_00_model_states.pt"
            )
            self._load_pretrained(pretrained_path)

    def _load_pretrained(self, path):
        """Load stage-1 weights from ``path`` while keeping the current CFM and denoiser weights.

        The checkpoint's ``"module"`` entry is loaded non-strictly into the whole
        model. Afterwards, ``self.lcfm.cfm`` and ``self.denoiser`` are restored to
        the weights they had before the call.
        """
        # Clone is necessary: state_dict() tensors share storage with the live
        # parameters, so an un-cloned snapshot would be overwritten by the load.
        cfm_state_dict = {k: v.clone() for k, v in self.lcfm.cfm.state_dict().items()}
        denoiser_state_dict = {
            k: v.clone() for k, v in self.denoiser.state_dict().items()
        }

        # NOTE: torch.load is pickle-based; only point this at trusted checkpoints.
        state_dict = torch.load(path, map_location="cpu")["module"]
        # strict=False: keys missing from (or extra in) the checkpoint are ignored.
        self.load_state_dict(state_dict, strict=False)

        self.lcfm.cfm.load_state_dict(cfm_state_dict)  # Reset cfm
        self.denoiser.load_state_dict(denoiser_state_dict)  # Reset denoiser
        logger.info(f"Loaded pretrained model from {path}")

    def summarize(self):
        """Return a Markdown table of parameter counts.

        It has one row per direct child module, plus a final ``total`` row for
        the whole enhancer. ``trainable`` counts parameters with
        ``requires_grad``; ``total`` counts all of them.
        ``DataFrame.to_markdown`` requires the optional ``tabulate`` package.
        """

        def count_trainable(module: nn.Module) -> int:
            return sum(p.numel() for p in module.parameters() if p.requires_grad)

        def count_all(module: nn.Module) -> int:
            return sum(p.numel() for p in module.parameters())

        rows = [
            dict(name=name, trainable=count_trainable(module), total=count_all(module))
            for name, module in self.named_children()
        ]
        rows.append(dict(name="total", trainable=count_trainable(self), total=count_all(self)))
        return pd.DataFrame(rows).to_markdown(index=False)

    def to_mel(self, x: Tensor, drop_last=True):
        """Convert waveforms to mel spectrograms with ``self.mel_fn``.

        Args:
            x: (b t), wavs
            drop_last: if true (default), drop the last frame along the final axis.
        Returns:
            o: (b c t), mels
        """
        mel = self.mel_fn(x)  # (b d t)
        return mel[..., :-1] if drop_last else mel

    def _may_denoise(self, x: Tensor, y: Union[Tensor, None] = None):
        """Run ``self.denoiser(x, y)`` in "cfm" training mode; otherwise return ``x`` unchanged."""
        if self.hp.lcfm_training_mode == "cfm":
            return self.denoiser(x, y)
        return x

    def configurate_(self, nfe, solver, lambd, tau):
        """Set inference options in place.

        Args:
            nfe: number of function evaluations
            solver: solver method
            lambd: denoiser strength [0, 1]; used by ``forward`` in eval mode
            tau: prior temperature [0, 1]
        """
        self.lcfm.cfm.solver.configurate_(nfe, solver)
        self.lcfm.eval_tau_(tau)
        self._eval_lambd = lambd

    def _mix_denoised_mel(
        self,
        x: Tensor,
        z: Union[Tensor, None],
        x_mel_original: Tensor,
        lambd: Union[float, Tensor],
    ) -> Tensor:
        """Blend the normalized mel of the denoised ``x`` with ``x_mel_original``.

        Args:
            x: (b t), mix wavs passed to the denoiser
            z: optional wavs forwarded to ``_may_denoise`` as its second argument
            x_mel_original: (b d t), normalized mel of the un-denoised ``x``
            lambd: denoiser strength; a float, or a tensor broadcastable to (b d t)
        Returns:
            ``lambd * denoised_mel + (1 - lambd) * x_mel_original``
        """
        x_mel_denoised = self.normalizer(
            self.to_mel(self._may_denoise(x, z)), update=False
        )
        # Detached: no gradient flows back through the denoiser path.
        x_mel_denoised = x_mel_denoised.detach()
        return lambd * x_mel_denoised + (1 - lambd) * x_mel_original

    def forward(
        self, x: Tensor, y: Union[Tensor, None] = None, z: Union[Tensor, None] = None
    ):
        """
        Args:
            x: (b t), mix wavs (fg + bg)
            y: (b t), fg clean wavs (optional)
            z: (b t), fg distorted wavs (optional; passed to the denoiser)
        Returns:
            o: (b t), reconstructed wavs, or None when ``self.lcfm`` returns None
        """
        assert x.dim() == 2, f"Expected (b t), got {x.size()}"
        assert y is None or y.dim() == 2, f"Expected (b t), got {y.size()}"
        # z is documented as (b t) too; check it like x and y.
        assert z is None or z.dim() == 2, f"Expected (b t), got {z.size()}"

        cfm_mode = self.hp.lcfm_training_mode == "cfm"

        if cfm_mode:
            # NOTE(review): presumably keeps the normalizer's statistics frozen in cfm mode — confirm.
            self.normalizer.eval()
            x = _normalize_wav(x)
            y = _maybe(_normalize_wav)(y)
            z = _maybe(_normalize_wav)(z)

        x_mel_original = self.normalizer(self.to_mel(x), update=False)  # (b d t)

        if not cfm_mode:
            x_mel_denoised = x_mel_original
        elif self.training:
            # Per-example denoiser strength drawn from Beta(0.2, 0.2), shaped
            # (b 1 1) so it broadcasts over the (b d t) mels.
            lambd = Beta(0.2, 0.2).sample(x.shape[:1]).to(x.device)
            lambd = lambd[:, None, None]
            x_mel_denoised = self._mix_denoised_mel(x, z, x_mel_original, lambd)
            # ``_visualize`` is not defined in this class (matplotlib is imported
            # but unused), so an unconditional call raised AttributeError on every
            # cfm-mode training step. Only call it if some subclass/mixin provides it.
            visualize = getattr(self, "_visualize", None)
            if visualize is not None:
                visualize(x_mel_original, x_mel_denoised)
        elif self._eval_lambd == 0:
            # The denoiser would contribute nothing, so skip running it.
            x_mel_denoised = x_mel_original
        else:
            x_mel_denoised = self._mix_denoised_mel(
                x, z, x_mel_original, self._eval_lambd
            )

        y_mel = _maybe(self.to_mel)(y)  # (b d t)
        y_mel = _maybe(self.normalizer)(y_mel)

        lcfm_decoded = self.lcfm(x_mel_denoised, y_mel, ψ0=x_mel_original)  # (b d t)
        if lcfm_decoded is None:
            return None
        return self.vocoder(lcfm_decoded, y)