|
|
|
|
|
|
|
|
|
|
|
import torch |
|
from diffusers import DDPMScheduler |
|
|
|
from models.svc.base import SVCTrainer |
|
from modules.encoder.condition_encoder import ConditionEncoder |
|
from .diffusion_wrapper import DiffusionWrapper |
|
|
|
|
|
class DiffusionTrainer(SVCTrainer):
    r"""The base trainer for all diffusion models. It inherits from SVCTrainer and
    implements ``_build_model`` and ``_forward_step`` methods.
    """

    def __init__(self, args=None, cfg=None):
        r"""Initialise the base trainer, then set up the diffusion noise scheduler.

        Args:
            args: Forwarded unchanged to ``SVCTrainer.__init__``.
            cfg: Forwarded unchanged to ``SVCTrainer.__init__``. After the base
                initialiser returns, ``self.cfg.model.diffusion.scheduler_settings``
                is unpacked as keyword arguments for ``DDPMScheduler``.
        """
        SVCTrainer.__init__(self, args, cfg)

        # The scheduler is created after the base initialiser so that
        # ``self.cfg`` is available.
        self.noise_scheduler = DDPMScheduler(
            **self.cfg.model.diffusion.scheduler_settings,
        )
        # Exclusive upper bound used when sampling timesteps in ``_forward_step``.
        self.diffusion_timesteps = (
            self.cfg.model.diffusion.scheduler_settings.num_train_timesteps
        )

    def _build_model(self):
        r"""Build the model for training. This function is called in ``__init__`` function.

        Returns:
            torch.nn.ModuleList: ``[condition_encoder, acoustic_mapper]``. Both
            modules are also stored on ``self`` under those attribute names.
        """
        # Propagate the preprocessing F0 range into the condition encoder's
        # config (this mutates ``self.cfg`` in place).
        self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
        self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max

        self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
        self.acoustic_mapper = DiffusionWrapper(self.cfg)
        model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])

        # Report the size of each sub-module and of the whole model, in millions.
        num_of_params_encoder = self.count_parameters(self.condition_encoder)
        num_of_params_am = self.count_parameters(self.acoustic_mapper)
        num_of_params = num_of_params_encoder + num_of_params_am
        log = "Diffusion Model's Parameters: #Encoder is {:.2f}M, #Diffusion is {:.2f}M. The total is {:.2f}M".format(
            num_of_params_encoder / 1e6, num_of_params_am / 1e6, num_of_params / 1e6
        )
        self.logger.info(log)

        return model

    def count_parameters(self, model):
        r"""Count the parameter elements of a module or of a dict of modules.

        Args:
            model: Either an object exposing ``parameters()`` or a ``dict``
                whose values each expose ``parameters()``.

        Returns:
            int: Sum of ``numel()`` over every parameter. The result is an
            integer for both input kinds; an empty dict yields ``0``.
        """
        if isinstance(model, dict):
            # Only the values are needed; the keys play no part in the count.
            return sum(
                p.numel() for module in model.values() for p in module.parameters()
            )
        return sum(p.numel() for p in model.parameters())

    def _check_nan(self, batch, loss, y_pred, y_gt):
        r"""Log the offending batch when the loss contains NaN, then run the base check.

        Args:
            batch: Mapping for the current step; every key and value is logged
                at info level when a NaN is detected.
            loss: Loss tensor to inspect.
            y_pred: Forwarded to the base class ``_check_nan``.
            y_gt: Forwarded to the base class ``_check_nan``.
        """
        if torch.any(torch.isnan(loss)):
            for k, v in batch.items():
                self.logger.info(k)
                self.logger.info(v)

            # NOTE(review): the base check is kept inside the NaN branch; it is
            # presumably a no-op for finite losses -- confirm against SVCTrainer.
            super()._check_nan(loss, y_pred, y_gt)

    def _forward_step(self, batch):
        r"""Forward step for training and inference. This function is called
        in ``_train_step`` & ``_test_step`` function.

        Args:
            batch: Mapping that provides at least ``"mel"`` and ``"mask"`` (after
                optional online feature extraction). The whole mapping is also
                passed to the condition encoder.

        Returns:
            The value of ``self._compute_loss`` between the model output and the
            noise that was added to the mel input.
        """
        device = self.accelerator.device

        if self.online_features_extraction:
            # Features are computed here when they are not already in the batch.
            batch = self._extract_svc_features(batch)

        mel_input = batch["mel"]
        noise = torch.randn_like(mel_input, device=device, dtype=torch.float32)
        batch_size = mel_input.size(0)
        # One timestep per batch item, drawn from [0, diffusion_timesteps).
        timesteps = torch.randint(
            0,
            self.diffusion_timesteps,
            (batch_size,),
            device=device,
            dtype=torch.long,
        )

        noisy_mel = self.noise_scheduler.add_noise(mel_input, noise, timesteps)
        conditioner = self.condition_encoder(batch)

        y_pred = self.acoustic_mapper(noisy_mel, timesteps, conditioner)

        # The training target is the injected noise itself, not the clean mel.
        loss = self._compute_loss(self.criterion, y_pred, noise, batch["mask"])
        self._check_nan(batch, loss, y_pred, noise)

        return loss
|
|