import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import itertools
from dataclasses import asdict
from models.model import Vocos
from dataset import VocosDataset
from models.discriminator import MultiPeriodDiscriminator, MultiResolutionDiscriminator
from models.loss import feature_loss, generator_loss, discriminator_loss, MultiScaleMelSpectrogramLoss, SingleScaleMelSpectrogramLoss
from config import MelConfig, VocosConfig, TrainConfig
from utils.scheduler import get_cosine_schedule_with_warmup
from utils.load import continue_training
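
# cuDNN benchmark mode auto-tunes convolution algorithms for the observed input
# shapes; training uses fixed-size audio segments, so shapes stay stable across batches.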
torch.backends.cudnn.benchmark = True

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12345'
    # NCCL is unavailable on Windows, so fall back to gloo there
    dist.init_process_group("gloo" if os.name == "nt" else "nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def _init_config(vocos_config: VocosConfig, mel_config: MelConfig, train_config: TrainConfig):
    if vocos_config.input_channels != mel_config.n_mels:
        raise ValueError("input_channels and n_mels must be equal.")
    if not os.path.exists(train_config.model_save_path):
        print(f'Creating {train_config.model_save_path}')
        os.makedirs(train_config.model_save_path, exist_ok=True)

def train(rank, world_size):
    setup(rank, world_size)
    torch.cuda.set_device(rank)

    vocos_config = VocosConfig()
    mel_config = MelConfig()
    train_config = TrainConfig()
    _init_config(vocos_config, mel_config, train_config)

    generator = Vocos(vocos_config, mel_config).to(rank)
    mpd = MultiPeriodDiscriminator().to(rank)
    mrd = MultiResolutionDiscriminator().to(rank)
    loss_fn = MultiScaleMelSpectrogramLoss().to(rank)

    if rank == 0:
        print(f"Generator params: {sum(p.numel() for p in generator.parameters()) / 1e6:.2f}M")
        print(f"Discriminator mpd params: {sum(p.numel() for p in mpd.parameters()) / 1e6:.2f}M")
        print(f"Discriminator mrd params: {sum(p.numel() for p in mrd.parameters()) / 1e6:.2f}M")
    generator = DDP(generator, device_ids=[rank])
    mpd = DDP(mpd, device_ids=[rank])
    mrd = DDP(mrd, device_ids=[rank])
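
    # DistributedSampler gives each rank a disjoint shard of the dataset;
    # sampler.set_epoch() below reshuffles the shards every epoch.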
    train_dataset = VocosDataset(train_config.train_dataset_path, train_config.segment_size, mel_config)
    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=train_config.batch_size, num_workers=4, pin_memory=False, persistent_workers=True)

    if rank == 0:
        writer = SummaryWriter(train_config.log_dir)
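
    # Separate AdamW optimizers for the generator and for both discriminators, each driven
    # by a warmup-then-cosine schedule (behaviour assumed from the helper's name) spanning
    # num_epochs * steps_per_epoch optimizer steps.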
    optimizer_g = optim.AdamW(generator.parameters(), lr=train_config.learning_rate)
    optimizer_d = optim.AdamW(itertools.chain(mpd.parameters(), mrd.parameters()), lr=train_config.learning_rate)
    scheduler_g = get_cosine_schedule_with_warmup(optimizer_g, num_warmup_steps=int(train_config.warmup_steps), num_training_steps=train_config.num_epochs * len(train_dataloader))
    scheduler_d = get_cosine_schedule_with_warmup(optimizer_d, num_warmup_steps=int(train_config.warmup_steps), num_training_steps=train_config.num_epochs * len(train_dataloader))

    # load the latest checkpoints if possible and resume from the saved epoch
    current_epoch = continue_training(train_config.model_save_path, generator, mpd, mrd, optimizer_d, optimizer_g)

    generator.train()
    mpd.train()
    mrd.train()

    for epoch in range(current_epoch, train_config.num_epochs):  # loop over the train_dataset multiple times
        train_dataloader.sampler.set_epoch(epoch)
        if rank == 0:
            dataloader = tqdm(train_dataloader)
        else:
            dataloader = train_dataloader

        for batch_idx, datas in enumerate(dataloader):
            datas = [data.to(rank, non_blocking=True) for data in datas]
            audios, mels = datas
            audios_fake = generator(mels).unsqueeze(1)  # shape: [batch_size, 1, segment_size]
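
            # --- discriminator step ---
            # The generated audio is detached so that only the discriminators
            # receive gradients from these losses.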
            optimizer_d.zero_grad()
            # MPD
            y_df_hat_r, y_df_hat_g, _, _ = mpd(audios, audios_fake.detach())
            loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
            # MRD
            y_ds_hat_r, y_ds_hat_g, _, _ = mrd(audios, audios_fake.detach())
            loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
            loss_disc_all = loss_disc_s + loss_disc_f
            loss_disc_all.backward()
            grad_norm_mpd = torch.nn.utils.clip_grad_norm_(mpd.parameters(), 1000)
            grad_norm_mrd = torch.nn.utils.clip_grad_norm_(mrd.parameters(), 1000)
            optimizer_d.step()
            scheduler_d.step()

            # --- generator step ---
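            # The fake audio is NOT detached here, so the adversarial, feature-matching
            # and mel-reconstruction losses all backpropagate into the generator.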
            optimizer_g.zero_grad()
            loss_mel = loss_fn(audios, audios_fake) * train_config.mel_loss_factor
            # MPD loss
            y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(audios, audios_fake)
            loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
            loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
            # MRD loss
            y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = mrd(audios, audios_fake)
            loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
            loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
            loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
            loss_gen_all.backward()
            grad_norm_g = torch.nn.utils.clip_grad_norm_(generator.parameters(), 1000)
            optimizer_g.step()
            scheduler_g.step()
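
            # rank 0 logs scalars to TensorBoard every log_interval batches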
            if rank == 0 and batch_idx % train_config.log_interval == 0:
                steps = epoch * len(dataloader) + batch_idx
                writer.add_scalar("training/gen_loss_total", loss_gen_all.item(), steps)
                writer.add_scalar("training/fm_loss_mpd", loss_fm_f.item(), steps)
                writer.add_scalar("training/gen_loss_mpd", loss_gen_f.item(), steps)
                writer.add_scalar("training/disc_loss_mpd", loss_disc_f.item(), steps)
                writer.add_scalar("training/fm_loss_mrd", loss_fm_s.item(), steps)
                writer.add_scalar("training/gen_loss_mrd", loss_gen_s.item(), steps)
                writer.add_scalar("training/disc_loss_mrd", loss_disc_s.item(), steps)
                writer.add_scalar("training/mel_loss", loss_mel.item(), steps)
                writer.add_scalar("grad_norm/grad_norm_mpd", grad_norm_mpd, steps)
                writer.add_scalar("grad_norm/grad_norm_mrd", grad_norm_mrd, steps)
                writer.add_scalar("grad_norm/grad_norm_g", grad_norm_g, steps)
                writer.add_scalar("learning_rate/learning_rate_d", scheduler_d.get_last_lr()[0], steps)
                writer.add_scalar("learning_rate/learning_rate_g", scheduler_g.get_last_lr()[0], steps)
        if rank == 0:
            torch.save(generator.module.state_dict(), os.path.join(train_config.model_save_path, f'generator_{epoch}.pt'))
            torch.save(mpd.module.state_dict(), os.path.join(train_config.model_save_path, f'mpd_{epoch}.pt'))
            torch.save(mrd.module.state_dict(), os.path.join(train_config.model_save_path, f'mrd_{epoch}.pt'))
            torch.save(optimizer_d.state_dict(), os.path.join(train_config.model_save_path, f'optimizerd_{epoch}.pt'))
            torch.save(optimizer_g.state_dict(), os.path.join(train_config.model_save_path, f'optimizerg_{epoch}.pt'))
        print(f"Rank {rank}, Epoch {epoch}, Loss {loss_gen_all.item()}")

    cleanup()

torch.set_num_threads(1)
torch.set_num_interop_threads(1)

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size)
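
# Example launch (assuming this file is saved as train.py):
#   CUDA_VISIBLE_DEVICES=0,1 python train.py
# spawn() starts one training process per visible GPU; rank 0 handles TensorBoard
# logging and checkpointing. Training curves can be inspected with:
#   tensorboard --logdir <TrainConfig.log_dir>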