import torch
import numpy as np
from torch.utils.data import ConcatDataset, DataLoader
from models.tts.naturalspeech2.base_trainer import TTSTrainer
from models.base.base_trainer import BaseTrainer
from models.base.base_sampler import VariableSampler
from models.tts.naturalspeech2.ns2_dataset import NS2Dataset, NS2Collator, batch_by_size
from models.tts.naturalspeech2.ns2_loss import (
    log_pitch_loss,
    log_dur_loss,
    diff_loss,
    diff_ce_loss,
)
from torch.utils.data.sampler import BatchSampler, SequentialSampler
from models.tts.naturalspeech2.ns2 import NaturalSpeech2
from torch.optim import Adam, AdamW
from torch.nn import MSELoss, L1Loss
import torch.nn.functional as F
from diffusers import get_scheduler


class NS2Trainer(TTSTrainer):
    """Trainer for NaturalSpeech2: pitch/duration prior losses plus a latent
    "diffusion" or "flow" objective, with an optional code cross-entropy term.
    """

    def __init__(self, args, cfg):
        TTSTrainer.__init__(self, args, cfg)

    def _build_model(self):
        """Instantiate the NaturalSpeech2 model from ``cfg.model``."""
        model = NaturalSpeech2(cfg=self.cfg.model)
        return model

    def _build_dataset(self):
        """Return the (dataset class, collator class) pair used by this trainer."""
        return NS2Dataset, NS2Collator

    # ------------------------------------------------------------------ data

    def _dynamic_batches(self, dataset):
        """Group ``dataset`` indices into frame-budgeted batches via ``batch_by_size``.

        The token/sentence budgets are multiplied by the process count because
        every batch is afterwards split across all processes (see
        ``_shard_batches``).
        """
        num_processes = self.accelerator.num_processes
        return batch_by_size(
            dataset.num_frame_indices,
            dataset.get_num_frames,
            max_tokens=self.cfg.train.max_tokens * num_processes,
            max_sentences=self.cfg.train.max_sentences * num_processes,
            required_batch_size_multiple=num_processes,
        )

    def _shard_batches(self, batch_sampler):
        """Return this process's strided slice of every evenly divisible batch.

        Batches whose length is not a multiple of the process count are
        dropped; the filter does not depend on the rank, so every process
        keeps the same number of batches.
        """
        num_processes = self.accelerator.num_processes
        rank = self.accelerator.local_process_index
        return [
            batch[rank::num_processes]
            for batch in batch_sampler
            if len(batch) % num_processes == 0
        ]

    def _build_dataloader(self):
        """Build ``(train_loader, valid_loader)``.

        With ``cfg.train.use_dynamic_batchsize`` the loaders use frame-budgeted
        batches sharded per process; otherwise they use a fixed
        ``cfg.train.batch_size`` with shuffling.
        """
        if self.cfg.train.use_dynamic_batchsize:
            print("Use Dynamic Batchsize......")
            Dataset, Collator = self._build_dataset()

            train_dataset = Dataset(self.cfg, self.cfg.dataset[0], is_valid=False)
            train_collate = Collator(self.cfg)
            batch_sampler = self._dynamic_batches(train_dataset)
            # Presumably fixed so every process shuffles the batch list the same
            # way before the per-rank slicing in _shard_batches.
            # NOTE: this reseeds numpy's *global* RNG for the whole process.
            np.random.seed(980205)
            np.random.shuffle(batch_sampler)
            print(batch_sampler[:1])
            train_batches = self._shard_batches(batch_sampler)
            train_loader = DataLoader(
                train_dataset,
                collate_fn=train_collate,
                num_workers=self.cfg.train.dataloader.num_worker,
                batch_sampler=VariableSampler(
                    train_batches, drop_last=False, use_random_sampler=True
                ),
                pin_memory=self.cfg.train.dataloader.pin_memory,
            )
            self.accelerator.wait_for_everyone()

            valid_dataset = Dataset(self.cfg, self.cfg.dataset[0], is_valid=True)
            valid_collate = Collator(self.cfg)
            valid_batches = self._shard_batches(self._dynamic_batches(valid_dataset))
            valid_loader = DataLoader(
                valid_dataset,
                collate_fn=valid_collate,
                num_workers=self.cfg.train.dataloader.num_worker,
                batch_sampler=VariableSampler(valid_batches, drop_last=False),
                pin_memory=self.cfg.train.dataloader.pin_memory,
            )
            self.accelerator.wait_for_everyone()
        else:
            print("Use Normal Batchsize......")
            Dataset, Collator = self._build_dataset()

            train_dataset = Dataset(self.cfg, self.cfg.dataset[0], is_valid=False)
            train_collate = Collator(self.cfg)
            train_loader = DataLoader(
                train_dataset,
                shuffle=True,
                collate_fn=train_collate,
                batch_size=self.cfg.train.batch_size,
                num_workers=self.cfg.train.dataloader.num_worker,
                pin_memory=self.cfg.train.dataloader.pin_memory,
            )

            valid_dataset = Dataset(self.cfg, self.cfg.dataset[0], is_valid=True)
            valid_collate = Collator(self.cfg)
            valid_loader = DataLoader(
                valid_dataset,
                shuffle=True,
                collate_fn=valid_collate,
                batch_size=self.cfg.train.batch_size,
                num_workers=self.cfg.train.dataloader.num_worker,
                pin_memory=self.cfg.train.dataloader.pin_memory,
            )
            self.accelerator.wait_for_everyone()

        return train_loader, valid_loader

    # ------------------------------------------------------- optimisation

    def _build_optimizer(self):
        """AdamW over trainable parameters, configured by ``cfg.train.adam``."""
        optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            **self.cfg.train.adam
        )
        return optimizer

    def _build_scheduler(self):
        """LR scheduler from ``diffusers.get_scheduler`` with warmup."""
        lr_scheduler = get_scheduler(
            self.cfg.train.lr_scheduler,
            optimizer=self.optimizer,
            num_warmup_steps=self.cfg.train.lr_warmup_steps,
            num_training_steps=self.cfg.train.num_train_steps,
        )
        return lr_scheduler

    def _build_criterion(self):
        criterion = torch.nn.L1Loss(reduction="mean")
        return criterion

    # ------------------------------------------------------------ logging

    def write_summary(self, losses, stats):
        """Log every training loss scalar at the current step."""
        for key, value in losses.items():
            self.sw.add_scalar(key, value, self.step)

    def write_valid_summary(self, losses, stats):
        """Log every validation loss scalar at the current step."""
        for key, value in losses.items():
            self.sw.add_scalar(key, value, self.step)

    # -------------------------------------------------------- checkpoints

    def get_state_dict(self):
        """Collect model/optimizer/scheduler state plus step, epoch and batch size."""
        state_dict = {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "scheduler": self.scheduler.state_dict(),
            "step": self.step,
            "epoch": self.epoch,
            "batch_size": self.cfg.train.batch_size,
        }
        return state_dict

    def load_model(self, checkpoint):
        """Restore the state written by ``get_state_dict``."""
        self.step = checkpoint["step"]
        self.epoch = checkpoint["epoch"]

        self.model.load_state_dict(checkpoint["model"])
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        self.scheduler.load_state_dict(checkpoint["scheduler"])

    # -------------------------------------------------------------- steps

    def _unwrapped_model(self):
        """Return the bare model with any accelerate/DDP wrapper removed.

        ``self.model.module`` only exists when accelerate wrapped the model
        (e.g. DistributedDataParallel in multi-process runs). ``unwrap_model``
        works both with and without a wrapper, so single-process training no
        longer fails with ``AttributeError``.
        """
        return self.accelerator.unwrap_model(self.model)

    def _compute_losses(self, batch):
        """Run the model on ``batch`` and compute the NS2 training objective.

        Returns:
            ``(total_loss, losses, ce_indices)``:

            - ``total_loss`` is the weighted sum that is optimised.
            - ``losses`` maps names to the *unweighted* loss tensors, in
              logging order.
            - ``ce_indices`` is ``(pred_indices, gt_indices)`` when
              ``cfg.train.diff_ce_loss_lambda > 0``, else ``None``.
        """
        losses = {}
        total_loss = 0

        code = batch["code"]  # (B, 16, T)
        pitch = batch["pitch"]  # (B, T)
        duration = batch["duration"]  # (B, N)
        phone_id = batch["phone_id"]  # (B, N)
        ref_code = batch["ref_code"]  # (B, 16, T')
        phone_mask = batch["phone_mask"]  # (B, N)
        mask = batch["mask"]  # (B, T)
        ref_mask = batch["ref_mask"]  # (B, T')

        diff_out, prior_out = self.model(
            code=code,
            pitch=pitch,
            duration=duration,
            phone_id=phone_id,
            ref_code=ref_code,
            phone_mask=phone_mask,
            mask=mask,
            ref_mask=ref_mask,
        )

        # pitch loss (weight 1)
        pitch_loss = log_pitch_loss(prior_out["pitch_pred_log"], pitch, mask=mask)
        total_loss += pitch_loss
        losses["pitch_loss"] = pitch_loss

        # duration loss (weight 1)
        dur_loss = log_dur_loss(prior_out["dur_pred_log"], duration, mask=phone_mask)
        total_loss += dur_loss
        losses["dur_loss"] = dur_loss

        model = self._unwrapped_model()
        x0 = model.code_to_latent(code)

        diffusion_type = self.cfg.model.diffusion.diffusion_type
        if diffusion_type == "diffusion":
            # x0 reconstruction loss (weight 1)
            diff_loss_x0 = diff_loss(diff_out["x0_pred"], x0, mask=mask)
            total_loss += diff_loss_x0
            losses["diff_loss_x0"] = diff_loss_x0

            # noise prediction loss, weighted by diff_noise_loss_lambda
            diff_loss_noise = diff_loss(
                diff_out["noise_pred"], diff_out["noise"], mask=mask
            )
            total_loss += diff_loss_noise * self.cfg.train.diff_noise_loss_lambda
            losses["diff_loss_noise"] = diff_loss_noise
        elif diffusion_type == "flow":
            # flow-matching target is (noise - x0)
            flow_gt = diff_out["noise"] - x0
            diff_loss_flow = diff_loss(diff_out["flow_pred"], flow_gt, mask=mask)
            total_loss += diff_loss_flow
            losses["diff_loss_flow"] = diff_loss_flow

        ce_indices = None
        if self.cfg.train.diff_ce_loss_lambda > 0:
            # (nq, B, T); (nq, B, T, 1024)
            pred_indices, pred_dist = model.latent_to_code(
                diff_out["x0_pred"], nq=code.shape[1]
            )
            gt_indices, _ = model.latent_to_code(x0, nq=code.shape[1])
            diff_loss_ce = diff_ce_loss(pred_dist, gt_indices, mask=mask)
            total_loss += diff_loss_ce * self.cfg.train.diff_ce_loss_lambda
            losses["diff_loss_ce"] = diff_loss_ce
            ce_indices = (pred_indices, gt_indices)

        return total_loss, losses, ce_indices

    @staticmethod
    def _codebook_accuracies(pred_indices, gt_indices, mask):
        """Return ``{"pred_acc_i": acc}``, one entry per leading index of the codes.

        Each ``acc`` is the masked fraction of positions where the predicted
        code equals the ground-truth code. The leading axis is presumably the
        quantizer layer (``nq``); confirm against ``latent_to_code``.
        """
        pred = pred_indices.long().detach().cpu().numpy()
        gt = gt_indices.long().detach().cpu().numpy()
        mask_np = mask.detach().cpu().numpy()
        denom = np.sum(mask_np)
        return {
            "pred_acc_{}".format(str(i)): np.sum((pred[i] == gt[i]) * mask_np) / denom
            for i in range(pred.shape[0])
        }

    def _train_step(self, batch):
        """One optimisation step.

        Returns:
            ``(total_loss, losses, stats)``, where ``losses`` holds Python
            scalars for logging.
        """
        train_stats = {}
        total_loss, train_losses, ce_indices = self._compute_losses(batch)

        self.optimizer.zero_grad()
        self.accelerator.backward(total_loss)
        if self.accelerator.sync_gradients:
            self.accelerator.clip_grad_norm_(
                filter(lambda p: p.requires_grad, self.model.parameters()), 0.5
            )
        self.optimizer.step()
        self.scheduler.step()

        for name in train_losses:
            train_losses[name] = train_losses[name].item()

        if ce_indices is not None:
            train_losses.update(
                self._codebook_accuracies(*ce_indices, batch["mask"])
            )

        train_losses["batch_size"] = batch["code"].shape[0]
        train_losses["max_frame_nums"] = np.max(
            batch["frame_nums"].detach().cpu().numpy()
        )

        return (total_loss.item(), train_losses, train_stats)

    @torch.inference_mode()
    def _valid_step(self, batch):
        """Compute the same losses and accuracies as training, without updates."""
        valid_stats = {}
        total_loss, valid_losses, ce_indices = self._compute_losses(batch)

        for name in valid_losses:
            valid_losses[name] = valid_losses[name].item()

        if ce_indices is not None:
            valid_losses.update(
                self._codebook_accuracies(*ce_indices, batch["mask"])
            )

        return (total_loss.item(), valid_losses, valid_stats)

    # def _train_epoch(self):
    #     ...