yuancwang
init
b725c5a
raw
history blame
No virus
5.12 kB
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from tqdm import tqdm
from models.tts.base import TTSTrainer
from models.tts.fastspeech2.fs2 import FastSpeech2, FastSpeech2Loss
from models.tts.fastspeech2.fs2_dataset import FS2Dataset, FS2Collator
from optimizer.optimizers import NoamLR
class FastSpeech2Trainer(TTSTrainer):
def __init__(self, args, cfg):
TTSTrainer.__init__(self, args, cfg)
self.cfg = cfg
def _build_dataset(self):
return FS2Dataset, FS2Collator
def __build_scheduler(self):
return NoamLR(self.optimizer, **self.cfg.train.lr_scheduler)
def _write_summary(self, losses, stats):
for key, value in losses.items():
self.sw.add_scalar("train/" + key, value, self.step)
lr = self.optimizer.state_dict()["param_groups"][0]["lr"]
self.sw.add_scalar("learning_rate", lr, self.step)
def _write_valid_summary(self, losses, stats):
for key, value in losses.items():
self.sw.add_scalar("val/" + key, value, self.step)
def _build_criterion(self):
return FastSpeech2Loss(self.cfg)
def get_state_dict(self):
state_dict = {
"model": self.model.state_dict(),
"optimizer": self.optimizer.state_dict(),
"scheduler": self.scheduler.state_dict(),
"step": self.step,
"epoch": self.epoch,
"batch_size": self.cfg.train.batch_size,
}
return state_dict
def _build_optimizer(self):
optimizer = torch.optim.Adam(self.model.parameters(), **self.cfg.train.adam)
return optimizer
def _build_scheduler(self):
scheduler = NoamLR(self.optimizer, **self.cfg.train.lr_scheduler)
return scheduler
def _build_model(self):
self.model = FastSpeech2(self.cfg)
return self.model
def _train_epoch(self):
r"""Training epoch. Should return average loss of a batch (sample) over
one epoch. See ``train_loop`` for usage.
"""
self.model.train()
epoch_sum_loss: float = 0.0
epoch_step: int = 0
epoch_losses: dict = {}
for batch in tqdm(
self.train_dataloader,
desc=f"Training Epoch {self.epoch}",
unit="batch",
colour="GREEN",
leave=False,
dynamic_ncols=True,
smoothing=0.04,
disable=not self.accelerator.is_main_process,
):
# Do training step and BP
with self.accelerator.accumulate(self.model):
loss, train_losses = self._train_step(batch)
self.accelerator.backward(loss)
grad_clip_thresh = self.cfg.train.grad_clip_thresh
nn.utils.clip_grad_norm_(self.model.parameters(), grad_clip_thresh)
self.optimizer.step()
self.scheduler.step()
self.optimizer.zero_grad()
self.batch_count += 1
# Update info for each step
if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
epoch_sum_loss += loss
for key, value in train_losses.items():
if key not in epoch_losses.keys():
epoch_losses[key] = value
else:
epoch_losses[key] += value
self.accelerator.log(
{
"Step/Train Loss": loss,
"Step/Learning Rate": self.optimizer.param_groups[0]["lr"],
},
step=self.step,
)
self.step += 1
epoch_step += 1
self.accelerator.wait_for_everyone()
epoch_sum_loss = (
epoch_sum_loss
/ len(self.train_dataloader)
* self.cfg.train.gradient_accumulation_step
)
for key in epoch_losses.keys():
epoch_losses[key] = (
epoch_losses[key]
/ len(self.train_dataloader)
* self.cfg.train.gradient_accumulation_step
)
return epoch_sum_loss, epoch_losses
def _train_step(self, data):
train_losses = {}
total_loss = 0
train_stats = {}
preds = self.model(data)
train_losses = self.criterion(data, preds)
total_loss = train_losses["loss"]
for key, value in train_losses.items():
train_losses[key] = value.item()
return total_loss, train_losses
@torch.no_grad()
def _valid_step(self, data):
valid_loss = {}
total_valid_loss = 0
valid_stats = {}
preds = self.model(data)
valid_losses = self.criterion(data, preds)
total_valid_loss = valid_losses["loss"]
for key, value in valid_losses.items():
valid_losses[key] = value.item()
return total_valid_loss, valid_losses, valid_stats