import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard import SummaryWriter

from tqdm import tqdm
import itertools
from dataclasses import asdict

from models.model import Vocos
from dataset import VocosDataset
from models.discriminator import MultiPeriodDiscriminator, MultiResolutionDiscriminator
from models.loss import feature_loss, generator_loss, discriminator_loss, MultiScaleMelSpectrogramLoss, SingleScaleMelSpectrogramLoss
from config import MelConfig, VocosConfig, TrainConfig
from utils.scheduler import get_cosine_schedule_with_warmup
from utils.load import continue_training

torch.backends.cudnn.benchmark = True
    
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12345'
    dist.init_process_group("gloo" if os.name == "nt" else "nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()
    
def _init_config(vocos_config: VocosConfig, mel_config: MelConfig, train_config: TrainConfig):
    if vocos_config.input_channels != mel_config.n_mels:
        raise ValueError("input_channels and n_mels must be equal.")
    
    if not os.path.exists(train_config.model_save_path):
        print(f'Creating {train_config.model_save_path}')
        os.makedirs(train_config.model_save_path, exist_ok=True)

def train(rank, world_size):
    setup(rank, world_size)
    torch.cuda.set_device(rank)

    vocos_config = VocosConfig()
    mel_config = MelConfig()
    train_config = TrainConfig()
    
    _init_config(vocos_config, mel_config, train_config)
    
    generator = Vocos(vocos_config, mel_config).to(rank)
    mpd = MultiPeriodDiscriminator().to(rank)
    mrd = MultiResolutionDiscriminator().to(rank)
    loss_fn = MultiScaleMelSpectrogramLoss().to(rank)
    if rank == 0:
        print(f"Generator params: {sum(p.numel() for p in generator.parameters()) / 1e6}")
        print(f"Discriminator mpd params: {sum(p.numel() for p in mpd.parameters()) / 1e6}")
        print(f"Discriminator mrd params: {sum(p.numel() for p in mrd.parameters()) / 1e6}")
    
    generator = DDP(generator, device_ids=[rank])
    mpd = DDP(mpd, device_ids=[rank])
    mrd = DDP(mrd, device_ids=[rank])

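    # Each rank loads a disjoint shard of the dataset via DistributedSampler;
    # set_epoch() in the epoch loop below reshuffles the sharding every epoch.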
    train_dataset = VocosDataset(train_config.train_dataset_path, train_config.segment_size, mel_config)
    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=train_config.batch_size, num_workers=4, pin_memory=True, persistent_workers=True)  # pinned memory lets the non_blocking=True copies below overlap with compute
    
    if rank == 0:
        writer = SummaryWriter(train_config.log_dir)

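    # Separate AdamW optimizers for the generator and the two discriminators,
    # each with a warmup + cosine-decay schedule stepped once per batch.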
    optimizer_g = optim.AdamW(generator.parameters(), lr=train_config.learning_rate)
    optimizer_d = optim.AdamW(itertools.chain(mpd.parameters(), mrd.parameters()), lr=train_config.learning_rate)
    scheduler_g = get_cosine_schedule_with_warmup(optimizer_g, num_warmup_steps=int(train_config.warmup_steps), num_training_steps=train_config.num_epochs * len(train_dataloader))
    scheduler_d = get_cosine_schedule_with_warmup(optimizer_d, num_warmup_steps=int(train_config.warmup_steps), num_training_steps=train_config.num_epochs * len(train_dataloader))
    
    # load latest checkpoints if possible
    current_epoch = continue_training(train_config.model_save_path, generator, mpd, mrd, optimizer_d, optimizer_g)
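    # Note: continue_training is expected to restore model and optimizer states from
    # the latest checkpoint and return the epoch to resume from; the LR schedulers are
    # not passed in, so they restart from step 0 when training resumes.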

    generator.train()
    mpd.train()
    mrd.train()
    for epoch in range(current_epoch, train_config.num_epochs):  # loop over the train_dataset multiple times
        train_dataloader.sampler.set_epoch(epoch)
        if rank == 0:
            dataloader = tqdm(train_dataloader)
        else:
            dataloader = train_dataloader
            
        for batch_idx, datas in enumerate(dataloader):
            datas = [data.to(rank, non_blocking=True) for data in datas]
            audios, mels = datas
            audios_fake = generator(mels).unsqueeze(1) # shape: [batch_size, 1, segment_size]
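            # ---- Discriminator update: score real audio vs. detached generator output ----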
            optimizer_d.zero_grad()
            
            # MPD
            y_df_hat_r, y_df_hat_g, _, _ = mpd(audios, audios_fake.detach())
            loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
            
            # MRD
            y_ds_hat_r, y_ds_hat_g, _, _ = mrd(audios, audios_fake.detach())
            loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
            
            loss_disc_all = loss_disc_s + loss_disc_f
            loss_disc_all.backward()
            
            grad_norm_mpd = torch.nn.utils.clip_grad_norm_(mpd.parameters(), 1000)
            grad_norm_mrd = torch.nn.utils.clip_grad_norm_(mrd.parameters(), 1000)
            optimizer_d.step()
            scheduler_d.step()
            
            # ---- Generator update: adversarial, feature-matching, and mel losses ----
            optimizer_g.zero_grad()
            loss_mel = loss_fn(audios, audios_fake) * train_config.mel_loss_factor
            
            # MPD loss
            y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(audios, audios_fake)
            loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
            loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)

            # MRD loss
            y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = mrd(audios, audios_fake)
            loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
            loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)

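            # Total generator objective: adversarial (MPD + MRD) + feature matching + weighted mel reconstruction.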
            loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
            loss_gen_all.backward()
            
            grad_norm_g = torch.nn.utils.clip_grad_norm_(generator.parameters(), 1000)
            optimizer_g.step()
            scheduler_g.step()
            
            if rank == 0 and batch_idx % train_config.log_interval == 0:
                steps = epoch * len(dataloader) + batch_idx
                writer.add_scalar("training/gen_loss_total", loss_gen_all, steps)
                writer.add_scalar("training/fm_loss_mpd", loss_fm_f.item(), steps)
                writer.add_scalar("training/gen_loss_mpd", loss_gen_f.item(), steps)
                writer.add_scalar("training/disc_loss_mpd", loss_disc_f.item(), steps)
                writer.add_scalar("training/fm_loss_mrd", loss_fm_s.item(), steps)
                writer.add_scalar("training/gen_loss_mrd", loss_gen_s.item(), steps)
                writer.add_scalar("training/disc_loss_mrd", loss_disc_s.item(), steps)
                writer.add_scalar("training/mel_loss", loss_mel.item(), steps)
                writer.add_scalar("grad_norm/grad_norm_mpd", grad_norm_mpd, steps)
                writer.add_scalar("grad_norm/grad_norm_mrd", grad_norm_mrd, steps)
                writer.add_scalar("grad_norm/grad_norm_g", grad_norm_g, steps)
                writer.add_scalar("learning_rate/learning_rate_d", scheduler_d.get_last_lr()[0], steps)
                writer.add_scalar("learning_rate/learning_rate_g", scheduler_g.get_last_lr()[0], steps)
            
        if rank == 0:
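            # Save the unwrapped (.module) weights and the optimizer states once per epoch, from rank 0 only.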
            torch.save(generator.module.state_dict(), os.path.join(train_config.model_save_path, f'generator_{epoch}.pt'))
            torch.save(mpd.module.state_dict(), os.path.join(train_config.model_save_path, f'mpd_{epoch}.pt'))
            torch.save(mrd.module.state_dict(), os.path.join(train_config.model_save_path, f'mrd_{epoch}.pt'))
            torch.save(optimizer_d.state_dict(), os.path.join(train_config.model_save_path, f'optimizerd_{epoch}.pt'))
            torch.save(optimizer_g.state_dict(), os.path.join(train_config.model_save_path, f'optimizerg_{epoch}.pt'))
        print(f"Rank {rank}, Epoch {epoch}, Loss {loss_gen_all.item()}")

    cleanup()
    
torch.set_num_threads(1)
torch.set_num_interop_threads(1)

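# Launch one training process per visible GPU; torch.multiprocessing.spawn passes
# the process index to train() as `rank`.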
if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size)