NaturalSpeech2 / models /tts /valle /valle_trainer.py
yuancwang
init
b725c5a
raw
history blame
No virus
12.8 kB
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
from tqdm import tqdm
import torch
import numpy as np
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel
from optimizer.optimizers import Eve, ScaledAdam
from schedulers.scheduler import NoamScheduler, Eden
from models.tts.valle.valle_dataset import (
VALLEDataset,
VALLECollator,
VariableSampler,
batch_by_size,
)
from models.tts.base import TTSTrainer
from models.tts.valle.valle import VALLE
class VALLETrainer(TTSTrainer):
def __init__(self, args, cfg):
TTSTrainer.__init__(self, args, cfg)
def _build_model(self):
model = VALLE(self.cfg.model)
return model
def _build_dataset(self):
return VALLEDataset, VALLECollator
def _build_optimizer(self):
if self.args.train_stage:
if isinstance(self.model, DistributedDataParallel):
model = self.model.module
else:
model = self.model
model_parameters = model.stage_parameters(self.args.train_stage)
else:
model_parameters = self.model.parameters()
if self.cfg.train.optimizer == "ScaledAdam":
parameters_names = []
if self.args.train_stage != 0:
parameters_names.append(
[
name_param_pair[0]
for name_param_pair in model.stage_named_parameters(
self.args.train_stage
)
]
)
else:
parameters_names.append(
[name_param_pair[0] for name_param_pair in model.named_parameters()]
)
optimizer = ScaledAdam(
model_parameters,
lr=self.cfg.train.base_lr,
betas=(0.9, 0.95),
clipping_scale=2.0,
parameters_names=parameters_names,
show_dominant_parameters=False,
clipping_update_period=1000,
)
elif self.cfg.train.optimizer == "Eve":
optimizer = Eve(
model_parameters,
lr=self.cfg.train.base_lr,
betas=(0.9, 0.98),
target_rms=0.1,
)
elif self.cfg.train.optimizer == "AdamW":
optimizer = torch.optim.AdamW(
model_parameters,
lr=self.cfg.train.base_lr,
betas=(0.9, 0.95),
weight_decay=1e-2,
eps=1e-8,
)
elif self.cfg.train.optimizer == "Adam":
optimizer = torch.optim.Adam(
model_parameters,
lr=self.cfg.train.base_lr,
betas=(0.9, 0.95),
eps=1e-8,
)
else:
raise NotImplementedError()
return optimizer
def _build_scheduler(self):
if self.cfg.train.scheduler.lower() == "eden":
scheduler = Eden(
self.optimizer, 5000, 4, warmup_batches=self.cfg.train.warmup_steps
)
elif self.cfg.train.scheduler.lower() == "noam":
scheduler = NoamScheduler(
self.cfg.train.base_lr,
self.optimizer,
self.cfg.model.decoder_dim,
warmup_steps=self.cfg.train.warmup_steps,
)
elif self.cfg.train.scheduler.lower() == "cosine":
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
self.cfg.train.warmup_steps,
self.optimizer,
eta_min=self.cfg.train.base_lr,
)
else:
raise NotImplementedError(f"{self.cfg.train.scheduler}")
return scheduler
def _train_epoch(self):
r"""Training epoch. Should return average loss of a batch (sample) over
one epoch. See ``train_loop`` for usage.
"""
if isinstance(self.model, dict):
for key in self.model.keys():
self.model[key].train()
else:
self.model.train()
epoch_sum_loss: float = 0.0
epoch_losses: dict = {}
epoch_step: int = 0
for batch in tqdm(
self.train_dataloader,
desc=f"Training Epoch {self.epoch}",
unit="batch",
colour="GREEN",
leave=False,
dynamic_ncols=True,
smoothing=0.04,
disable=not self.accelerator.is_main_process,
):
# Do training step and BP
with self.accelerator.accumulate(self.model):
total_loss, train_losses = self._train_step(batch)
self.accelerator.backward(total_loss)
self.optimizer.step()
self.optimizer.zero_grad()
self.batch_count += 1
if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
if self.cfg.train.optimizer not in ["ScaledAdam", "Eve"]:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
for k in range(self.cfg.train.gradient_accumulation_step):
if isinstance(self.scheduler, Eden):
self.scheduler.step_batch(self.step)
else:
self.scheduler.step()
epoch_sum_loss += total_loss.detach().cpu().item()
if isinstance(train_losses, dict):
for key, value in train_losses.items():
if key not in epoch_losses.keys():
epoch_losses[key] = value
else:
epoch_losses[key] += value
if isinstance(train_losses, dict):
for key, loss in train_losses.items():
self.accelerator.log(
{"Step/Train {}".format(key): "{:.6f}".format(loss)},
step=self.step,
)
else:
self.accelerator.log(
{"Step/Train Loss": loss},
step=self.step,
)
self.accelerator.log(
{"Step/lr": self.scheduler.get_last_lr()[0]},
step=self.step,
)
# print loss every log_epoch_step steps
# if epoch_step % self.cfg.train.log_epoch_step == 0:
# for key, loss in train_losses.items():
# self.logger.info("Step/Train {}: {:.6f}".format(key, loss))
# print("Step/Train {}: {:.6f}".format(key, loss))
self.step += 1
epoch_step += 1
self.accelerator.wait_for_everyone()
epoch_sum_loss = (
epoch_sum_loss
/ len(self.train_dataloader)
* self.cfg.train.gradient_accumulation_step
)
for key in epoch_losses.keys():
epoch_losses[key] = (
epoch_losses[key]
/ len(self.train_dataloader)
* self.cfg.train.gradient_accumulation_step
)
return epoch_sum_loss, epoch_losses
def _train_step(self, batch, is_training=True):
text_tokens = batch["phone_seq"].to(self.device)
text_tokens_lens = batch["phone_len"].to(self.device)
assert text_tokens.ndim == 2
audio_features = batch["acoustic_token"].to(self.device)
audio_features_lens = batch["target_len"].to(self.device)
assert audio_features.ndim == 3
with torch.set_grad_enabled(is_training):
loss, losses = self.model(
x=text_tokens,
x_lens=text_tokens_lens,
y=audio_features,
y_lens=audio_features_lens,
train_stage=self.args.train_stage,
)
assert loss.requires_grad == is_training
loss_dict = {}
frames_sum = (audio_features_lens).sum()
avg_loss = loss / frames_sum
loss_dict["loss"] = avg_loss.detach().cpu().item()
for l in losses:
loss_dict[l] = losses[l].detach().cpu().item() / frames_sum.item()
return avg_loss, loss_dict
def _valid_step(self, batch):
valid_losses = {}
total_loss = 0
valid_stats = {}
total_loss, valid_losses = self._train_step(
batch=batch,
is_training=False,
)
assert total_loss.requires_grad is False
total_loss = total_loss.detach().cpu().item()
return total_loss, valid_losses, valid_stats
def add_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--train_stage",
type=int,
default="1",
help="0: train all modules, 1: AR Decoder, 2: NAR Decoder",
)
def _build_dataloader(self):
if not self.cfg.train.use_dynamic_batchsize:
return super()._build_dataloader()
if len(self.cfg.dataset) > 1:
raise Exception("use_dynamic_batchsize only supports single dataset now.")
Dataset, Collator = self._build_dataset()
train_dataset = Dataset(
self.cfg, self.cfg.dataset[0], is_valid=False
) # TODO: support use_dynamic_batchsize for more than one datasets.
train_collate = Collator(self.cfg)
batch_sampler = batch_by_size(
train_dataset.num_frame_indices,
train_dataset.get_num_frames,
max_tokens=self.cfg.train.max_tokens * self.accelerator.num_processes,
max_sentences=self.cfg.train.max_sentences * self.accelerator.num_processes,
required_batch_size_multiple=self.accelerator.num_processes,
)
np.random.seed(1234)
np.random.shuffle(batch_sampler)
print(batch_sampler[:1])
batches = [
x[self.accelerator.local_process_index :: self.accelerator.num_processes]
for x in batch_sampler
if len(x) % self.accelerator.num_processes == 0
]
train_loader = DataLoader(
train_dataset,
collate_fn=train_collate,
num_workers=self.cfg.train.dataloader.num_worker,
batch_sampler=VariableSampler(
batches, drop_last=False, use_random_sampler=True
),
pin_memory=False,
)
self.accelerator.wait_for_everyone()
valid_dataset = Dataset(self.cfg, self.cfg.dataset[0], is_valid=True)
valid_collate = Collator(self.cfg)
batch_sampler = batch_by_size(
valid_dataset.num_frame_indices,
valid_dataset.get_num_frames,
max_tokens=self.cfg.train.max_tokens * self.accelerator.num_processes,
max_sentences=self.cfg.train.max_sentences * self.accelerator.num_processes,
required_batch_size_multiple=self.accelerator.num_processes,
)
batches = [
x[self.accelerator.local_process_index :: self.accelerator.num_processes]
for x in batch_sampler
if len(x) % self.accelerator.num_processes == 0
]
valid_loader = DataLoader(
valid_dataset,
collate_fn=valid_collate,
num_workers=self.cfg.train.dataloader.num_worker,
batch_sampler=VariableSampler(batches, drop_last=False),
pin_memory=False,
)
self.accelerator.wait_for_everyone()
return train_loader, valid_loader
def _accelerator_prepare(self):
if not self.cfg.train.use_dynamic_batchsize:
(
self.train_dataloader,
self.valid_dataloader,
) = self.accelerator.prepare(
self.train_dataloader,
self.valid_dataloader,
)
if isinstance(self.model, dict):
for key in self.model.keys():
self.model[key] = self.accelerator.prepare(self.model[key])
else:
self.model = self.accelerator.prepare(self.model)
if isinstance(self.optimizer, dict):
for key in self.optimizer.keys():
self.optimizer[key] = self.accelerator.prepare(self.optimizer[key])
else:
self.optimizer = self.accelerator.prepare(self.optimizer)
if isinstance(self.scheduler, dict):
for key in self.scheduler.keys():
self.scheduler[key] = self.accelerator.prepare(self.scheduler[key])
else:
self.scheduler = self.accelerator.prepare(self.scheduler)