# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------

from dataclasses import dataclass, field
from typing import Any

import torch
from omegaconf import II

from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import GuidedAttentionLoss
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass

from artst.models.modules.speech_encoder_prenet import SpeechEncoderPrenet


@dataclass
class TexttoSpeechLossConfig(FairseqDataclass):
    use_masking: bool = field(
        default=True,
        metadata={"help": "Whether to use masking in calculation of loss"},
    )
    use_weighted_masking: bool = field(
        default=False,
        metadata={"help": "Whether to use weighted masking in calculation of loss"},
    )
    loss_type: str = field(
        default="L1",
        metadata={"help": "How to calculate the loss (L1, L2 or L1+L2)"},
    )
    bce_pos_weight: float = field(
        default=5.0,
        metadata={"help": "Positive sample weight in BCE calculation (only for use-masking=True)"},
    )
    bce_loss_lambda: float = field(
        default=1.0,
        metadata={"help": "Lambda in BCE loss"},
    )
    use_guided_attn_loss: bool = field(
        default=False,
        metadata={"help": "Whether to use guided attention loss"},
    )
    guided_attn_loss_sigma: float = field(
        default=0.4,
        metadata={"help": "Sigma in guided attention loss"},
    )
    guided_attn_loss_lambda: float = field(
        default=10.0,
        metadata={"help": "Lambda in guided attention loss"},
    )
    num_layers_applied_guided_attn: int = field(
        default=2,
        metadata={"help": "Number of layers to apply guided attention loss to; if set to -1, all layers are used"},
    )
    num_heads_applied_guided_attn: int = field(
        default=2,
        metadata={"help": "Number of heads per layer to apply guided attention loss to; if set to -1, all heads are used"},
    )
    modules_applied_guided_attn: Any = field(
        default=("encoder-decoder",),
        metadata={"help": "List of module names to apply guided attention loss to"},
    )
    sentence_avg: bool = II("optimization.sentence_avg")
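
# Note: when this criterion is registered with fairseq (via register_criterion
# imported above), fairseq's dataclass utilities expose each config field as a
# CLI flag. A hypothetical invocation (flag names derived from the field names;
# the criterion name is a placeholder) might look like:
#
#   fairseq-train ... --criterion <criterion_name> \
#       --use-guided-attn-loss --guided-attn-loss-sigma 0.4 \
#       --guided-attn-loss-lambda 10.0 --bce-pos-weight 5.0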
class TexttoSpeechLoss(FairseqCriterion):
    def __init__(
        self,
        task,
        sentence_avg,
        use_masking=True,
        use_weighted_masking=False,
        loss_type="L1",
        bce_pos_weight=5.0,
        bce_loss_lambda=1.0,
        use_guided_attn_loss=False,
        guided_attn_loss_sigma=0.4,
        guided_attn_loss_lambda=1.0,
        num_layers_applied_guided_attn=2,
        num_heads_applied_guided_attn=2,
        modules_applied_guided_attn=["encoder-decoder"],
    ):
        super().__init__(task)
        self.sentence_avg = sentence_avg
        self.use_masking = use_masking
        self.use_weighted_masking = use_weighted_masking
        self.loss_type = loss_type
        self.bce_pos_weight = bce_pos_weight
        self.bce_loss_lambda = bce_loss_lambda
        self.use_guided_attn_loss = use_guided_attn_loss
        self.guided_attn_loss_sigma = guided_attn_loss_sigma
        self.guided_attn_loss_lambda = guided_attn_loss_lambda
        # define loss function
        self.criterion = Tacotron2Loss(
            use_masking=use_masking,
            use_weighted_masking=use_weighted_masking,
            bce_pos_weight=bce_pos_weight,
        )
        if self.use_guided_attn_loss:
            self.num_layers_applied_guided_attn = num_layers_applied_guided_attn
            self.num_heads_applied_guided_attn = num_heads_applied_guided_attn
            self.modules_applied_guided_attn = modules_applied_guided_attn
            self.attn_criterion = GuidedMultiHeadAttentionLoss(
                sigma=guided_attn_loss_sigma,
                alpha=guided_attn_loss_lambda,
            )

    def forward(self, model, sample):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample["net_input"])
        loss, l1_loss, l2_loss, bce_loss, enc_dec_attn_loss = self.compute_loss(model, net_output, sample)
        # sample_size = (
        #     sample["target"].size(0) if self.sentence_avg else sample["nframes"]
        # )
        sample_size = 1
        logging_output = {
            "loss": loss.item(),
            "l1_loss": l1_loss.item(),
            "l2_loss": l2_loss.item(),
            "bce_loss": bce_loss.item(),
            "sample_size": 1,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
        }
        if enc_dec_attn_loss is not None:
            logging_output["enc_dec_attn_loss"] = enc_dec_attn_loss.item()

        if hasattr(model, "text_encoder_prenet"):
            logging_output["encoder_alpha"] = model.text_encoder_prenet.encoder_prenet[-1].alpha.item()
            logging_output["decoder_alpha"] = model.speech_decoder_prenet.decoder_prenet[-1].alpha.item()
        elif hasattr(model, "speech_encoder_prenet"):
            logging_output["decoder_alpha"] = model.speech_decoder_prenet.decoder_prenet[-1].alpha.item()
        else:
            if "task" not in sample:
                logging_output["encoder_alpha"] = model.encoder_prenet.encoder_prenet[-1].alpha.item()
                logging_output["decoder_alpha"] = model.decoder_prenet.decoder_prenet[-1].alpha.item()

        return loss, sample_size, logging_output

    def compute_loss(self, model, net_output, sample):
        before_outs, after_outs, logits, attn = net_output
        labels = sample["labels"]
        ys = sample["dec_target"]
        olens = sample["dec_target_lengths"]
        ilens = sample["src_lengths"]

        # trim the groundtruth so its length is a multiple of the reduction factor
        if model.reduction_factor > 1:
            olens_in = olens.new([torch.div(olen, model.reduction_factor, rounding_mode="floor") for olen in olens])
            olens = olens.new([olen - olen % model.reduction_factor for olen in olens])
            max_olen = max(olens)
            ys = ys[:, :max_olen]
            labels = labels[:, :max_olen]
            labels = torch.scatter(labels, 1, (olens - 1).unsqueeze(1), 1.0)  # make sure at least one frame has 1
            # labels[:, -1] = 1.0
        else:
            olens_in = olens

        # calculate loss values
        l1_loss, l2_loss, bce_loss = self.criterion(
            after_outs, before_outs, logits, ys, labels, olens
        )
        # l1_loss = l1_loss / ys.size(2)
        # l2_loss = l2_loss / ys.size(2)
        if self.loss_type == "L1":
            loss = (l1_loss + self.bce_loss_lambda * bce_loss) if self.bce_loss_lambda > 0.0 else l1_loss
        elif self.loss_type == "L2":
            loss = (l2_loss + self.bce_loss_lambda * bce_loss) if self.bce_loss_lambda > 0.0 else l2_loss
        elif self.loss_type == "L1+L2":
            loss = (l1_loss + l2_loss + self.bce_loss_lambda * bce_loss) if self.bce_loss_lambda > 0.0 else l1_loss + l2_loss
        else:
            raise ValueError("unknown --loss-type " + self.loss_type)

        # calculate guided attention loss
        enc_dec_attn_loss = None
        if self.use_guided_attn_loss:
            # calculate the encoder input lengths, which are determined by the encoder prenet
            if hasattr(model, "encoder_reduction_factor") and model.encoder_reduction_factor > 1:
                ilens_in = ilens.new([ilen // model.encoder_reduction_factor for ilen in ilens])
            else:
                ilens_in = ilens
            # handle the speech-to-speech model's input
            if "task_name" in sample and sample["task_name"] == "s2s":
                m = None
                if hasattr(model, "encoder_prenet"):
                    m = model.encoder_prenet
                elif hasattr(model, "speech_encoder_prenet"):
                    m = model.speech_encoder_prenet
                if m is not None and isinstance(m, SpeechEncoderPrenet):
                    ilens_in = m.get_src_lengths(ilens_in)
            # calculate for encoder-decoder
            if "encoder-decoder" in self.modules_applied_guided_attn:
                attn = [att_l[:, : self.num_heads_applied_guided_attn] for att_l in attn]
                att_ws = torch.cat(attn, dim=1)  # (B, H*L, T_out, T_in)
                enc_dec_attn_loss = self.attn_criterion(att_ws, ilens_in, olens_in)
                loss = loss + enc_dec_attn_loss

        return loss, l1_loss, l2_loss, bce_loss, enc_dec_attn_loss
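
    # A minimal sketch of the `sample` dict this criterion expects, inferred
    # from the reads in forward/compute_loss above (the shapes are assumptions:
    # B = batch, Lmax = max target length, odim = feature dim):
    #
    #   sample = {
    #       "net_input": {...},          # kwargs forwarded to model(**net_input)
    #       "labels": stop_labels,       # (B, Lmax) 0/1 stop-token targets
    #       "dec_target": feats,         # (B, Lmax, odim) target features
    #       "dec_target_lengths": olens, # (B,) target lengths
    #       "src_lengths": ilens,        # (B,) source lengths
    #       "ntokens": n,                # int, logging only
    #       "target": feats,             # used for the nsentences logging entry
    #       "task_name": "s2s",          # optional, speech-to-speech inputs
    #   }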
    @classmethod
    def reduce_metrics(cls, logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        l1_loss_sum = sum(log.get("l1_loss", 0) for log in logging_outputs)
        l2_loss_sum = sum(log.get("l2_loss", 0) for log in logging_outputs)
        bce_loss_sum = sum(log.get("bce_loss", 0) for log in logging_outputs)
        sample_size = max(1, sum(log.get("sample_size", 0) for log in logging_outputs))
        metrics.log_scalar(
            "loss", loss_sum / sample_size, sample_size, 1, round=5
        )
        encoder_alpha_sum = sum(log.get("encoder_alpha", 0) for log in logging_outputs)
        decoder_alpha_sum = sum(log.get("decoder_alpha", 0) for log in logging_outputs)
        ngpu = sum(log.get("ngpu", 0) for log in logging_outputs)
        metrics.log_scalar(
            "l1_loss", l1_loss_sum / sample_size, sample_size, 2, round=5
        )
        metrics.log_scalar(
            "l2_loss", l2_loss_sum / sample_size, sample_size, 2, round=5
        )
        metrics.log_scalar(
            "bce_loss", bce_loss_sum / sample_size, sample_size, 2, round=5
        )
        metrics.log_scalar(
            "encoder_alpha", encoder_alpha_sum / sample_size, sample_size, round=5
        )
        metrics.log_scalar(
            "decoder_alpha", decoder_alpha_sum / sample_size, sample_size, round=5
        )
        if "enc_dec_attn_loss" in logging_outputs[0]:
            enc_dec_attn_loss_sum = sum(log.get("enc_dec_attn_loss", 0) for log in logging_outputs)
            metrics.log_scalar(
                "enc_dec_attn_loss", enc_dec_attn_loss_sum / sample_size, sample_size, round=8
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
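
# For reference, the total objective assembled in compute_loss above is, for
# loss_type="L1" with bce_loss_lambda > 0 and guided attention enabled:
#
#   loss = l1_loss + bce_loss_lambda * bce_loss + enc_dec_attn_loss
#
# where l1_loss itself sums the pre- and post-postnet terms,
#   l1_loss = L1(after_outs, ys) + L1(before_outs, ys),
# and enc_dec_attn_loss is already scaled by guided_attn_loss_lambda (passed
# as `alpha` to GuidedMultiHeadAttentionLoss below).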
class Tacotron2Loss(torch.nn.Module):
    """Loss function module for Tacotron2."""

    def __init__(
        self, use_masking=True, use_weighted_masking=False, bce_pos_weight=20.0
    ):
        """Initialize Tacotron2 loss module.

        Args:
            use_masking (bool): Whether to apply masking for padded parts in loss calculation.
            use_weighted_masking (bool): Whether to apply weighted masking in loss calculation.
            bce_pos_weight (float): Weight of positive samples of the stop token.
        """
        super(Tacotron2Loss, self).__init__()
        assert (use_masking != use_weighted_masking) or not use_masking
        self.use_masking = use_masking
        self.use_weighted_masking = use_weighted_masking

        # define criterions
        # reduction = "none" if self.use_weighted_masking else "sum"
        reduction = "none" if self.use_weighted_masking else "mean"
        self.l1_criterion = torch.nn.L1Loss(reduction=reduction)
        self.mse_criterion = torch.nn.MSELoss(reduction=reduction)
        self.bce_criterion = torch.nn.BCEWithLogitsLoss(
            reduction=reduction, pos_weight=torch.tensor(bce_pos_weight)
        )

        # NOTE(kan-bayashi): register pre hook function for the compatibility
        self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)

    def forward(self, after_outs, before_outs, logits, ys, labels, olens):
        """Calculate forward propagation.

        Args:
            after_outs (Tensor): Batch of outputs after postnets (B, Lmax, odim).
            before_outs (Tensor): Batch of outputs before postnets (B, Lmax, odim).
            logits (Tensor): Batch of stop logits (B, Lmax).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            labels (LongTensor): Batch of the sequences of stop token labels (B, Lmax).
            olens (LongTensor): Batch of the lengths of each target (B,).

        Returns:
            Tensor: L1 loss value.
            Tensor: Mean square error loss value.
            Tensor: Binary cross entropy loss value.
        """
        # make mask and apply it
        if self.use_masking:
            masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
            ys = ys.masked_select(masks)
            after_outs = after_outs.masked_select(masks)
            before_outs = before_outs.masked_select(masks)
            labels = labels.masked_select(masks[:, :, 0])
            logits = logits.masked_select(masks[:, :, 0])

        # calculate loss
        l1_loss = self.l1_criterion(after_outs, ys) + self.l1_criterion(before_outs, ys)
        mse_loss = self.mse_criterion(after_outs, ys) + self.mse_criterion(
            before_outs, ys
        )
        bce_loss = self.bce_criterion(logits, labels)

        # make weighted mask and apply it
        if self.use_weighted_masking:
            masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
            weights = masks.float() / masks.sum(dim=1, keepdim=True).float()
            out_weights = weights.div(ys.size(0) * ys.size(2))
            logit_weights = weights.div(ys.size(0))

            # apply weight
            l1_loss = l1_loss.mul(out_weights).masked_select(masks).sum()
            mse_loss = mse_loss.mul(out_weights).masked_select(masks).sum()
            bce_loss = (
                bce_loss.mul(logit_weights.squeeze(-1))
                .masked_select(masks.squeeze(-1))
                .sum()
            )

        return l1_loss, mse_loss, bce_loss

    def _load_state_dict_pre_hook(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Apply pre hook function before loading state dict.

        From v.0.6.1 `bce_criterion.pos_weight` param is registered as a parameter but
        old models do not include it and as a result, it causes a missing key error when
        loading old model parameters. This function solves the issue by adding the param
        to the state dict before loading, as a pre hook function of the
        `load_state_dict` method.
        """
        key = prefix + "bce_criterion.pos_weight"
        if key not in state_dict:
            state_dict[key] = self.bce_criterion.pos_weight
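
# A minimal usage sketch of Tacotron2Loss with dummy tensors (hypothetical
# values; shapes follow the forward() docstring, with Lmax == max(olens) so
# the non-pad masks broadcast against the feature tensors):
#
#   criterion = Tacotron2Loss(use_masking=True, bce_pos_weight=5.0)
#   B, Lmax, odim = 2, 8, 80
#   after_outs = torch.randn(B, Lmax, odim)
#   before_outs = torch.randn(B, Lmax, odim)
#   logits = torch.randn(B, Lmax)
#   ys = torch.randn(B, Lmax, odim)
#   labels = torch.zeros(B, Lmax)
#   labels[0, 7] = labels[1, 5] = 1.0  # stop token on each last valid frame
#   olens = torch.tensor([8, 6])
#   l1, mse, bce = criterion(after_outs, before_outs, logits, ys, labels, olens)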
""" if self.guided_attn_masks is None: self.guided_attn_masks = ( self._make_guided_attention_masks(ilens, olens) .to(att_ws.device) .unsqueeze(1) ) if self.masks is None: self.masks = self._make_masks(ilens, olens).to(att_ws.device).unsqueeze(1) losses = self.guided_attn_masks * att_ws loss = torch.mean(losses.masked_select(self.masks)) if self.reset_always: self._reset_masks() return self.alpha * loss def _make_guided_attention_masks(self, ilens, olens): n_batches = len(ilens) max_ilen = max(ilens) max_olen = max(olens) guided_attn_masks = torch.zeros((n_batches, max_olen, max_ilen), device=olens.device) for idx, (ilen, olen) in enumerate(zip(ilens, olens)): guided_attn_masks[idx, :olen, :ilen] = self._make_guided_attention_mask( ilen, olen, self.sigma ) return guided_attn_masks @staticmethod def _make_guided_attention_mask(ilen, olen, sigma): grid_x, grid_y = torch.meshgrid(torch.arange(olen, device=olen.device), torch.arange(ilen, device=olen.device)) grid_x, grid_y = grid_x.float(), grid_y.float() return 1.0 - torch.exp( -((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma**2)) ) @staticmethod def _make_masks(ilens, olens): in_masks = make_non_pad_mask(ilens).to(ilens.device) # (B, T_in) out_masks = make_non_pad_mask(olens).to(olens.device) # (B, T_out) return out_masks.unsqueeze(-1) & in_masks.unsqueeze(-2) # (B, T_out, T_in)