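"""Training entry point for the enhancer model.

Builds the generator (Enhancer) and discriminator (univnet Discriminator) as
DeepSpeed engines, wires them into a TrainLoop, and periodically writes
evaluation wavs and mel plots.
"""
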
import argparse
import random
from functools import partial
from pathlib import Path

import soundfile
import torch
from deepspeed import DeepSpeedConfig
from torch import Tensor
from tqdm import tqdm

from ..data import create_dataloaders, mix_fg_bg
from ..utils import Engine, TrainLoop, save_mels, setup_logging, tree_map
from ..utils.distributed import is_local_leader
from .enhancer import Enhancer
from .hparams import HParams
from .univnet.discriminator import Discriminator


def load_G(run_dir: Path, hp: HParams | None = None, training=True):
    if hp is None:
        hp = HParams.load(run_dir)
    assert isinstance(hp, HParams)
    model = Enhancer(hp)
    engine = Engine(model=model, config_class=DeepSpeedConfig(hp.deepspeed_config), ckpt_dir=run_dir / "ds" / "G")
    if training:
        engine.load_checkpoint()
    else:
        engine.load_checkpoint(load_optimizer_states=False, load_lr_scheduler_states=False)
    return engine
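
# A minimal usage sketch (assumed run directory, not part of this module):
#   engine = load_G(Path("runs/enhancer"), training=False)
#   enhancer = engine.module.eval()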


def load_D(run_dir: Path, hp: HParams | None = None):
    if hp is None:
        hp = HParams.load(run_dir)
    assert isinstance(hp, HParams)
    model = Discriminator(hp)
    engine = Engine(model=model, config_class=DeepSpeedConfig(hp.deepspeed_config), ckpt_dir=run_dir / "ds" / "D")
    engine.load_checkpoint()
    return engine


def save_wav(path: Path, wav: Tensor, rate: int):
    wav = wav.detach().cpu().numpy()
    soundfile.write(path, wav, samplerate=rate)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("run_dir", type=Path)
    parser.add_argument("--yaml", type=Path, default=None)
    parser.add_argument("--device", type=str, default="cuda")
    args = parser.parse_args()

    setup_logging(args.run_dir)
    hp = HParams.load(args.run_dir, yaml=args.yaml)

    if is_local_leader():
        hp.save_if_not_exists(args.run_dir)
        hp.print()

    train_dl, val_dl = create_dataloaders(hp, mode="enhancer")

    def feed_G(engine: Engine, batch: dict[str, Tensor]):
        if hp.lcfm_training_mode == "ae":
            pred = engine(batch["fg_wavs"], batch["fg_wavs"])
        elif hp.lcfm_training_mode == "cfm":
            alpha_fn = lambda: random.uniform(*hp.mix_alpha_range)
            mx_dwavs = mix_fg_bg(batch["fg_dwavs"], batch["bg_dwavs"], alpha=alpha_fn)
            pred = engine(mx_dwavs, batch["fg_wavs"], batch["fg_dwavs"])
        else:
            raise ValueError(f"Unknown training mode: {hp.lcfm_training_mode}")
        losses = engine.gather_attribute("losses")
        return pred, losses

    def feed_D(engine: Engine, batch: dict | None, fake: Tensor):
        if batch is None:
            losses = engine(fake=fake)
        else:
            losses = engine(fake=fake, real=batch["fg_wavs"])
        return losses
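
    # feed_G and feed_D are the per-step callbacks handed to TrainLoop below: feed_G runs
    # the generator ("ae" reconstructs clean fg wavs, "cfm" denoises a random fg/bg mix),
    # and feed_D scores the generator output (real samples are omitted when batch is None,
    # which is assumed to be the generator-side adversarial pass).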

    @torch.no_grad()
    def eval_fn(engine: Engine, eval_dir, n_saved=10):
        assert isinstance(hp, HParams)
        model = engine.module
        model.eval()
        step = engine.global_step

        for i, batch in enumerate(tqdm(val_dl), 1):
            batch = tree_map(lambda x: x.to(args.device) if isinstance(x, Tensor) else x, batch)

            fg_wavs = batch["fg_wavs"]  # 1 t

            if hp.lcfm_training_mode == "ae":
                in_dwavs = fg_wavs
            elif hp.lcfm_training_mode == "cfm":
                in_dwavs = mix_fg_bg(fg_wavs, batch["bg_dwavs"])
            else:
                raise ValueError(f"Unknown training mode: {hp.lcfm_training_mode}")

            pred_fg_wavs = model(in_dwavs)  # 1 t

            in_mels = model.to_mel(in_dwavs)  # 1 c t
            fg_mels = model.to_mel(fg_wavs)  # 1 c t
            pred_fg_mels = model.to_mel(pred_fg_wavs)  # 1 c t

            rate = model.hp.wav_rate
            get_path = lambda suffix: eval_dir / f"step_{step:08}_{i:03}{suffix}"

            save_wav(get_path("_input.wav"), in_dwavs[0], rate=rate)
            save_wav(get_path("_predict.wav"), pred_fg_wavs[0], rate=rate)
            save_wav(get_path("_target.wav"), fg_wavs[0], rate=rate)

            save_mels(
                get_path(".png"),
                cond_mel=in_mels[0].cpu().numpy(),
                pred_mel=pred_fg_mels[0].cpu().numpy(),
                targ_mel=fg_mels[0].cpu().numpy(),
            )

            if i >= n_saved:
                break
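
    # TrainLoop is assumed to alternate generator/discriminator updates, enabling the
    # adversarial (feed_D) path only once global_step reaches gan_training_start_step.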

    train_loop = TrainLoop(
        run_dir=args.run_dir,
        train_dl=train_dl,
        load_G=partial(load_G, hp=hp),
        load_D=partial(load_D, hp=hp),
        device=args.device,
        feed_G=feed_G,
        feed_D=feed_D,
        eval_fn=eval_fn,
        gan_training_start_step=hp.gan_training_start_step,
    )

    train_loop.run(max_steps=hp.max_steps)


if __name__ == "__main__":
    main()
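
# Example invocation (module path and yaml path are assumptions, shown for illustration):
#   python -m resemble_enhance.enhancer.train runs/enhancer --yaml config/enhancer.yaml --device cuda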