from argparse import ArgumentParser

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
import wandb
from loguru import logger
from mmengine import Config
from mmengine.optim import OPTIMIZERS
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from torch.utils.data import DataLoader

from fish_diffusion.archs.diffsinger import DiffSinger
from fish_diffusion.datasets import DATASETS
from fish_diffusion.datasets.repeat import RepeatDataset
from fish_diffusion.utils.scheduler import LR_SCHEUDLERS
from fish_diffusion.utils.viz import viz_synth_sample
from fish_diffusion.vocoders import VOCODERS


class FishDiffusion(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.save_hyperparameters()

        self.model = DiffSinger(config.model)
        self.config = config

        # Vocoder: converts mel spectrograms back to audio. It is frozen and
        # only used for visualization during validation, never trained.
        self.vocoder = VOCODERS.build(config.model.vocoder)
        self.vocoder.freeze()

    def configure_optimizers(self):
        self.config.optimizer.params = self.parameters()
        optimizer = OPTIMIZERS.build(self.config.optimizer)

        self.config.scheduler.optimizer = optimizer
        # NB: "LR_SCHEUDLERS" is the registry's actual (misspelled) name in fish_diffusion.
        scheduler = LR_SCHEUDLERS.build(self.config.scheduler)

        # Step the scheduler every optimization step, not every epoch.
        return [optimizer], dict(scheduler=scheduler, interval="step")

    def _step(self, batch, batch_idx, mode):
        assert batch["pitches"].shape[1] == batch["mels"].shape[1]

        pitches = batch["pitches"].clone()
        batch_size = batch["speakers"].shape[0]

        output = self.model(
            speakers=batch["speakers"],
            contents=batch["contents"],
            src_lens=batch["content_lens"],
            max_src_len=batch["max_content_len"],
            mels=batch["mels"],
            mel_lens=batch["mel_lens"],
            max_mel_len=batch["max_mel_len"],
            pitches=batch["pitches"],
        )

        self.log(f"{mode}_loss", output["loss"], batch_size=batch_size, sync_dist=True)

        if mode != "valid":
            return output["loss"]

        # During validation, also run the diffusion sampler and log the
        # reconstructed and predicted audio for each sample in the batch.
        x = self.model.diffusion(output["features"])

        for idx, (gt_mel, gt_pitch, predict_mel, predict_mel_len) in enumerate(
            zip(batch["mels"], pitches, x, batch["mel_lens"])
        ):
            image_mels, wav_reconstruction, wav_prediction = viz_synth_sample(
                gt_mel=gt_mel,
                gt_pitch=gt_pitch,
                predict_mel=predict_mel,
                predict_mel_len=predict_mel_len,
                vocoder=self.vocoder,
                return_image=False,
            )

            wav_reconstruction = wav_reconstruction.to(torch.float32).cpu().numpy()
            wav_prediction = wav_prediction.to(torch.float32).cpu().numpy()

            # WandB logger
            if isinstance(self.logger, WandbLogger):
                self.logger.experiment.log(
                    {
                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
                        "wavs": [
                            wandb.Audio(
                                wav_reconstruction,
                                sample_rate=44100,
                                caption="reconstruction (gt)",
                            ),
                            wandb.Audio(
                                wav_prediction,
                                sample_rate=44100,
                                caption="prediction",
                            ),
                        ],
                    },
                )

            # TensorBoard logger
            if isinstance(self.logger, TensorBoardLogger):
                self.logger.experiment.add_figure(
                    f"sample-{idx}/mels",
                    image_mels,
                    global_step=self.global_step,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/gt",
                    wav_reconstruction,
                    self.global_step,
                    sample_rate=44100,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/prediction",
                    wav_prediction,
                    self.global_step,
                    sample_rate=44100,
                )

            if isinstance(image_mels, plt.Figure):
                plt.close(image_mels)

        return output["loss"]

    def training_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, mode="train")

    def validation_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, mode="valid")


if __name__ == "__main__":
    pl.seed_everything(42, workers=True)

    parser = ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--resume", type=str, default=None)
    parser.add_argument(
        "--tensorboard",
        action="store_true",
        default=False,
        help="Use tensorboard logger, default is wandb.",
    )
    parser.add_argument("--resume-id", type=str, default=None, help="Wandb run id.")
    parser.add_argument("--entity", type=str, default=None, help="Wandb entity.")
    parser.add_argument("--name", type=str, default=None, help="Wandb run name.")
    parser.add_argument(
        "--pretrained", type=str, default=None, help="Pretrained model."
    )
    parser.add_argument(
        "--only-train-speaker-embeddings",
        action="store_true",
        default=False,
        help="Only train speaker embeddings.",
    )

    args = parser.parse_args()

    cfg = Config.fromfile(args.config)

    model = FishDiffusion(cfg)

    # We only load the state_dict of the model, not the optimizer.
    if args.pretrained:
        state_dict = torch.load(args.pretrained, map_location="cpu")
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]

        result = model.load_state_dict(state_dict, strict=False)

        missing_keys = set(result.missing_keys)
        unexpected_keys = set(result.unexpected_keys)

        # Make sure incorrect keys are just noise predictor keys.
        unexpected_keys = unexpected_keys - set(
            i.replace(".naive_noise_predictor.", ".") for i in missing_keys
        )

        assert len(unexpected_keys) == 0

    if args.only_train_speaker_embeddings:
        for name, param in model.named_parameters():
            if "speaker_encoder" not in name:
                param.requires_grad = False

        logger.info(
            "Only train speaker embeddings, all other parameters are frozen."
        )

    # Use a distinct name here so we don't shadow the loguru logger imported above.
    trainer_logger = (
        TensorBoardLogger("logs", name=cfg.model.type)
        if args.tensorboard
        else WandbLogger(
            project=cfg.model.type,
            save_dir="logs",
            log_model=True,
            name=args.name,
            entity=args.entity,
            resume="must" if args.resume_id else False,
            id=args.resume_id,
        )
    )

    trainer = pl.Trainer(
        logger=trainer_logger,
        **cfg.trainer,
    )

    train_dataset = DATASETS.build(cfg.dataset.train)
    train_loader = DataLoader(
        train_dataset,
        collate_fn=train_dataset.collate_fn,
        **cfg.dataloader.train,
    )

    valid_dataset = DATASETS.build(cfg.dataset.valid)
    # Repeat the validation set once per device so every rank sees the
    # full set when validating under distributed training.
    valid_dataset = RepeatDataset(
        valid_dataset, repeat=trainer.num_devices, collate_fn=valid_dataset.collate_fn
    )

    valid_loader = DataLoader(
        valid_dataset,
        collate_fn=valid_dataset.collate_fn,
        **cfg.dataloader.valid,
    )

    trainer.fit(model, train_loader, valid_loader, ckpt_path=args.resume)