import torch
import matplotlib

# Select the non-interactive Agg backend before the pyplot-style interface is
# imported, so figures are rendered off-screen.
matplotlib.use("Agg")
import matplotlib.pylab as plt


def get_padding(k, d):
    """Return ``int((k * d - d) / 2)``.

    NOTE(review): presumably the "same" padding for a convolution with kernel
    size ``k`` and dilation ``d`` -- confirm against the callers.
    """
    return int((k * d - d) / 2)


def plot_spectrogram(spectrogram):
    """Render ``spectrogram`` as an image with a colorbar and return the figure.

    The data is drawn with ``imshow`` (origin at the bottom, no interpolation)
    on a 10x2 inch figure. The canvas is drawn before the figure is closed, and
    the closed ``Figure`` object is returned.

    Args:
        spectrogram: Array-like accepted by ``Axes.imshow``.

    Returns:
        The matplotlib ``Figure`` that was drawn.
    """
    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)

    fig.canvas.draw()
    # Close this specific figure rather than whatever pyplot currently regards
    # as the "current" figure, so unrelated open figures are never affected
    # and this one is always released from pyplot's registry.
    plt.close(fig)

    return fig


def save_checkpoint(
    checkpoint_dir,
    generator,
    discriminator,
    optimizer_generator,
    optimizer_discriminator,
    scheduler_generator,
    scheduler_discriminator,
    step,
    loss,
    best,
    logger,
):
    """Write the training state to ``checkpoint_dir / f"model-{step}.pt"``.

    The saved dictionary holds the ``state_dict()`` of each model, optimizer
    and scheduler under the ``"generator"`` and ``"discriminator"`` keys, plus
    the ``"step"`` and ``"loss"`` values. The directory (and any missing
    parents) is created if needed.

    Args:
        checkpoint_dir: Path-like object supporting ``mkdir`` and ``/``
            (e.g. ``pathlib.Path``).
        generator, discriminator: Objects exposing ``state_dict()``.
        optimizer_generator, optimizer_discriminator: Objects exposing
            ``state_dict()``.
        scheduler_generator, scheduler_discriminator: Objects exposing
            ``state_dict()``.
        step: Value stored under ``"step"`` and used in the file name.
        loss: Value stored under ``"loss"``.
        best: If truthy, the same state is also written to ``model-best.pt``.
        logger: Object with an ``info`` method; receives a one-line message.
    """
    generator_state = {
        "model": generator.state_dict(),
        "optimizer": optimizer_generator.state_dict(),
        "scheduler": scheduler_generator.state_dict(),
    }
    discriminator_state = {
        "model": discriminator.state_dict(),
        "optimizer": optimizer_discriminator.state_dict(),
        "scheduler": scheduler_discriminator.state_dict(),
    }
    state = {
        "generator": generator_state,
        "discriminator": discriminator_state,
        "step": step,
        "loss": loss,
    }

    checkpoint_dir.mkdir(exist_ok=True, parents=True)
    checkpoint_path = checkpoint_dir / f"model-{step}.pt"
    torch.save(state, checkpoint_path)

    if best:
        best_path = checkpoint_dir / "model-best.pt"
        torch.save(state, best_path)

    logger.info(f"Saved checkpoint: {checkpoint_path.stem}")


def load_checkpoint(
    load_path,
    generator,
    discriminator,
    optimizer_generator,
    optimizer_discriminator,
    scheduler_generator,
    scheduler_discriminator,
    rank,
    logger,
    finetune=False,
):
    """Restore training state from a file written by ``save_checkpoint``.

    Model weights are always loaded into ``generator`` and ``discriminator``.
    Optimizer and scheduler states are loaded only when ``finetune`` is false.

    Args:
        load_path: Location of the checkpoint file passed to ``torch.load``.
        generator, discriminator: Objects exposing ``load_state_dict``.
        optimizer_generator, optimizer_discriminator: Objects exposing
            ``load_state_dict``; untouched when ``finetune`` is true.
        scheduler_generator, scheduler_discriminator: Objects exposing
            ``load_state_dict``; untouched when ``finetune`` is true.
        rank: Device index; storages saved on ``cuda:0`` are remapped to
            ``cuda:{rank}``.
        logger: Object with an ``info`` method.
        finetune: When true, skip restoring optimizer and scheduler state.

    Returns:
        Tuple ``(step, loss)`` read from the checkpoint.
    """
    logger.info(f"Loading checkpoint from {load_path}")

    # NOTE(review): torch.load deserializes with pickle; only load checkpoint
    # files that come from a trusted source.
    device_map = {"cuda:0": f"cuda:{rank}"}
    checkpoint = torch.load(load_path, map_location=device_map)

    generator_state = checkpoint["generator"]
    discriminator_state = checkpoint["discriminator"]

    generator.load_state_dict(generator_state["model"])
    discriminator.load_state_dict(discriminator_state["model"])

    if not finetune:
        optimizer_generator.load_state_dict(generator_state["optimizer"])
        scheduler_generator.load_state_dict(generator_state["scheduler"])
        optimizer_discriminator.load_state_dict(discriminator_state["optimizer"])
        scheduler_discriminator.load_state_dict(discriminator_state["scheduler"])

    return checkpoint["step"], checkpoint["loss"]