Spaces:
Runtime error
Runtime error
# -------------------------------------------------------- | |
# ArTST: Arabic Text and Speech Transform (https://arxiv.org/abs/2310.16621) | |
# Github source: https://github.com/mbzuai-nlp/ArTST | |
# Based on speecht5, fairseq and espnet code bases | |
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet | |
# -------------------------------------------------------- | |
from dataclasses import dataclass, field | |
import torch | |
from fairseq import metrics, utils | |
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask | |
from fairseq.criterions import FairseqCriterion, register_criterion | |
from fairseq.dataclass import FairseqDataclass | |
from artst.models.modules.speech_encoder_prenet import SpeechEncoderPrenet | |
from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import GuidedAttentionLoss | |
from omegaconf import II | |
from typing import Any | |
@dataclass
class TexttoSpeechLossConfig(FairseqDataclass):
    """Configuration for the text-to-speech criterion.

    The fields mirror the keyword arguments of ``TexttoSpeechLoss.__init__``.
    The ``@dataclass`` decorator is required: without it the ``field(...)``
    declarations below stay plain class attributes and are not registered as
    dataclass fields.
    """

    use_masking: bool = field(
        default=True,
        metadata={"help": "Whether to use masking in calculation of loss"},
    )
    use_weighted_masking: bool = field(
        default=False,
        metadata={"help": "Whether to use weighted masking in calculation of loss"},
    )
    # One of "L1", "L2" or "L1+L2" (the values accepted by TexttoSpeechLoss).
    loss_type: str = field(
        default="L1",
        metadata={"help": "How to calc loss"},
    )
    bce_pos_weight: float = field(
        default=5.0,
        metadata={"help": "Positive sample weight in BCE calculation (only for use-masking=True)"},
    )
    bce_loss_lambda: float = field(
        default=1.0,
        metadata={"help": "Lambda in bce loss"},
    )
    use_guided_attn_loss: bool = field(
        default=False,
        metadata={"help": "Whether to use guided attention loss"},
    )
    guided_attn_loss_sigma: float = field(
        default=0.4,
        metadata={"help": "Sigma in guided attention loss"},
    )
    # NOTE(review): this default (10.0) differs from the TexttoSpeechLoss
    # constructor default (1.0); confirm which one is intended.
    guided_attn_loss_lambda: float = field(
        default=10.0,
        metadata={"help": "Lambda in guided attention loss"},
    )
    num_layers_applied_guided_attn: int = field(
        default=2,
        metadata={"help": "Number of layers to be applied guided attention loss, if set -1, all of the layers will be applied."},
    )
    num_heads_applied_guided_attn: int = field(
        default=2,
        metadata={"help": "Number of heads in each layer to be applied guided attention loss, if set -1, all of the heads will be applied."},
    )
    # Immutable tuple default, so no default_factory is needed.
    modules_applied_guided_attn: Any = field(
        default=("encoder-decoder",),
        metadata={"help": "Module name list to be applied guided attention loss"},
    )
    # Resolved from the ``optimization.sentence_avg`` setting via OmegaConf interpolation.
    sentence_avg: bool = II("optimization.sentence_avg")
class TexttoSpeechLoss(FairseqCriterion):
    """Text-to-speech training criterion.

    The total loss has up to three parts:

    * A feature-reconstruction term chosen by ``loss_type`` (``"L1"``,
      ``"L2"`` or ``"L1+L2"``), computed by :class:`Tacotron2Loss`.
    * ``bce_loss_lambda`` times the stop-token BCE loss, added only when
      ``bce_loss_lambda > 0``.
    * An optional guided attention loss on the encoder-decoder attention
      weights, added when ``use_guided_attn_loss`` is set.
    """

    def __init__(
        self,
        task,
        sentence_avg,
        use_masking=True,
        use_weighted_masking=False,
        loss_type="L1",
        bce_pos_weight=5.0,
        bce_loss_lambda=1.0,
        use_guided_attn_loss=False,
        guided_attn_loss_sigma=0.4,
        guided_attn_loss_lambda=1.0,
        num_layers_applied_guided_attn=2,
        num_heads_applied_guided_attn=2,
        modules_applied_guided_attn=("encoder-decoder",),
    ):
        """Build the criterion.

        Args:
            task: fairseq task, forwarded to ``FairseqCriterion``.
            sentence_avg (bool): Stored on ``self``. Not otherwise used here,
                because ``forward`` always reports a sample size of 1.
            use_masking (bool): Exclude padded frames from the feature and
                stop-token losses.
            use_weighted_masking (bool): Weight frames by inverse utterance
                length instead of plain masking.
            loss_type (str): ``"L1"``, ``"L2"`` or ``"L1+L2"``.
            bce_pos_weight (float): Positive-class weight of the stop-token BCE.
            bce_loss_lambda (float): Scale of the BCE term. The term is
                skipped when this is ``<= 0``.
            use_guided_attn_loss (bool): Add the guided attention loss.
            guided_attn_loss_sigma (float): Sigma of the guided attention mask.
            guided_attn_loss_lambda (float): Scale of the guided attention loss.
            num_layers_applied_guided_attn (int): Stored for reference. Not
                read by ``compute_loss``.
            num_heads_applied_guided_attn (int): Number of leading heads per
                layer that are used. ``-1`` selects every head.
            modules_applied_guided_attn (Sequence[str]): Modules the guided
                loss is applied to. Only ``"encoder-decoder"`` is handled.
        """
        super().__init__(task)
        self.sentence_avg = sentence_avg
        self.use_masking = use_masking
        self.use_weighted_masking = use_weighted_masking
        self.loss_type = loss_type
        self.bce_pos_weight = bce_pos_weight
        self.bce_loss_lambda = bce_loss_lambda
        self.use_guided_attn_loss = use_guided_attn_loss
        self.guided_attn_loss_sigma = guided_attn_loss_sigma
        self.guided_attn_loss_lambda = guided_attn_loss_lambda
        # Feature (L1 / MSE) and stop-token (BCE) losses.
        self.criterion = Tacotron2Loss(
            use_masking=use_masking,
            use_weighted_masking=use_weighted_masking,
            bce_pos_weight=bce_pos_weight,
        )
        if self.use_guided_attn_loss:
            self.num_layers_applied_guided_attn = num_layers_applied_guided_attn
            self.num_heads_applied_guided_attn = num_heads_applied_guided_attn
            self.modules_applied_guided_attn = modules_applied_guided_attn
            self.attn_criterion = GuidedMultiHeadAttentionLoss(
                sigma=guided_attn_loss_sigma,
                alpha=guided_attn_loss_lambda,
            )

    def forward(self, model, sample):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample["net_input"])
        loss, l1_loss, l2_loss, bce_loss, enc_dec_attn_loss = self.compute_loss(
            model, net_output, sample
        )
        # The sample size is fixed to 1 because Tacotron2Loss already averages
        # its losses (reduction="mean") or normalizes them (weighted masking).
        # No further division by batch or frame count is wanted.
        sample_size = 1
        logging_output = {
            "loss": loss.item(),
            "l1_loss": l1_loss.item(),
            "l2_loss": l2_loss.item(),
            "bce_loss": bce_loss.item(),
            "sample_size": sample_size,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
        }
        if enc_dec_attn_loss is not None:
            logging_output["enc_dec_attn_loss"] = enc_dec_attn_loss.item()
        # Log the ``alpha`` of the last layer of each prenet. The prenet
        # attribute names depend on the model variant.
        # NOTE(review): presumably a learnable positional-encoding scale; confirm.
        if hasattr(model, "text_encoder_prenet"):
            logging_output["encoder_alpha"] = model.text_encoder_prenet.encoder_prenet[-1].alpha.item()
            logging_output["decoder_alpha"] = model.speech_decoder_prenet.decoder_prenet[-1].alpha.item()
        elif hasattr(model, "speech_encoder_prenet"):
            logging_output["decoder_alpha"] = model.speech_decoder_prenet.decoder_prenet[-1].alpha.item()
        else:
            if "task" not in sample:
                logging_output["encoder_alpha"] = model.encoder_prenet.encoder_prenet[-1].alpha.item()
                logging_output["decoder_alpha"] = model.decoder_prenet.decoder_prenet[-1].alpha.item()
        return loss, sample_size, logging_output

    def compute_loss(self, model, net_output, sample):
        """Compute the total loss and its components.

        Args:
            model: Model providing ``reduction_factor``, and optionally
                ``encoder_reduction_factor`` and encoder prenets.
            net_output: Tuple ``(before_outs, after_outs, logits, attn)``.
            sample: Batch dict with ``labels``, ``dec_target``,
                ``dec_target_lengths`` and ``src_lengths``. It may also
                contain ``task_name``.

        Returns:
            ``(loss, l1_loss, l2_loss, bce_loss, enc_dec_attn_loss)``.
            ``enc_dec_attn_loss`` is ``None`` unless the guided attention loss
            is enabled for ``"encoder-decoder"``.

        Raises:
            ValueError: If ``loss_type`` is not recognized.
        """
        before_outs, after_outs, logits, attn = net_output
        labels = sample["labels"]
        ys = sample["dec_target"]
        olens = sample["dec_target_lengths"]
        ilens = sample["src_lengths"]

        reduction_factor = model.reduction_factor
        if reduction_factor > 1:
            # Decoder steps seen by the attention (one step per
            # ``reduction_factor`` frames).
            olens_in = torch.div(olens, reduction_factor, rounding_mode="floor")
            # Trim targets down to a multiple of the reduction factor.
            olens = olens - olens % reduction_factor
            max_olen = int(olens.max())
            ys = ys[:, :max_olen]
            labels = labels[:, :max_olen]
            # Make sure at least one frame per utterance is labelled as stop.
            labels = torch.scatter(labels, 1, (olens - 1).unsqueeze(1), 1.0)
        else:
            olens_in = olens

        l1_loss, l2_loss, bce_loss = self.criterion(
            after_outs, before_outs, logits, ys, labels, olens
        )
        if self.loss_type == "L1":
            loss = l1_loss
        elif self.loss_type == "L2":
            loss = l2_loss
        elif self.loss_type == "L1+L2":
            loss = l1_loss + l2_loss
        else:
            raise ValueError("unknown --loss-type " + self.loss_type)
        if self.bce_loss_lambda > 0.0:
            loss = loss + self.bce_loss_lambda * bce_loss

        enc_dec_attn_loss = None
        if self.use_guided_attn_loss:
            # Encoder input lengths, as shortened by the encoder prenet.
            encoder_rf = getattr(model, "encoder_reduction_factor", 1)
            ilens_in = ilens // encoder_rf if encoder_rf > 1 else ilens
            # Speech-to-speech input: the speech encoder prenet changes lengths.
            if "task_name" in sample and sample["task_name"] == "s2s":
                m = None
                if hasattr(model, "encoder_prenet"):
                    m = model.encoder_prenet
                elif hasattr(model, "speech_encoder_prenet"):
                    m = model.speech_encoder_prenet
                if m is not None and isinstance(m, SpeechEncoderPrenet):
                    ilens_in = m.get_src_lengths(ilens_in)
            if "encoder-decoder" in self.modules_applied_guided_attn:
                num_heads = self.num_heads_applied_guided_attn
                # -1 means "all heads". A plain ``[:, :-1]`` slice would
                # silently drop the last head instead.
                head_slice = slice(None) if num_heads == -1 else slice(num_heads)
                att_ws = torch.cat(
                    [att_l[:, head_slice] for att_l in attn], dim=1
                )  # (B, H*L, T_out, T_in)
                enc_dec_attn_loss = self.attn_criterion(att_ws, ilens_in, olens_in)
                loss = loss + enc_dec_attn_loss

        return loss, l1_loss, l2_loss, bce_loss, enc_dec_attn_loss

    @classmethod
    def reduce_metrics(cls, logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""

        def _total(key):
            return sum(log.get(key, 0) for log in logging_outputs)

        sample_size = max(1, _total("sample_size"))
        metrics.log_scalar(
            "loss", _total("loss") / sample_size, sample_size, 1, round=5
        )
        metrics.log_scalar(
            "l1_loss", _total("l1_loss") / sample_size, sample_size, 2, round=5
        )
        metrics.log_scalar(
            "l2_loss", _total("l2_loss") / sample_size, sample_size, 2, round=5
        )
        metrics.log_scalar(
            "bce_loss", _total("bce_loss") / sample_size, sample_size, 2, round=5
        )
        metrics.log_scalar(
            "encoder_alpha", _total("encoder_alpha") / sample_size, sample_size, round=5
        )
        metrics.log_scalar(
            "decoder_alpha", _total("decoder_alpha") / sample_size, sample_size, round=5
        )
        # Check every output, not only the first. Indexing ``[0]`` would also
        # raise on an empty list.
        if any("enc_dec_attn_loss" in log for log in logging_outputs):
            metrics.log_scalar(
                "enc_dec_attn_loss",
                _total("enc_dec_attn_loss") / sample_size,
                sample_size,
                round=8,
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
class Tacotron2Loss(torch.nn.Module):
    """Tacotron2 feature and stop-token loss.

    Yields three values:

    * An L1 term, summed over the pre- and post-postnet outputs.
    * An MSE term, summed over the same two outputs.
    * A BCE-with-logits stop-token term.
    """

    def __init__(
        self, use_masking=True, use_weighted_masking=False, bce_pos_weight=20.0
    ):
        """Set up the three underlying criteria.

        Args:
            use_masking (bool): Remove padded frames (according to ``olens``)
                before computing mean-reduced losses.
            use_weighted_masking (bool): Keep element-wise losses and weight
                every valid frame by the inverse of its utterance length.
                Cannot be combined with ``use_masking``.
            bce_pos_weight (float): ``pos_weight`` given to positive (stop)
                labels in the BCE loss.
        """
        super().__init__()
        # Plain and weighted masking are mutually exclusive.
        assert not use_masking or use_masking != use_weighted_masking
        self.use_masking = use_masking
        self.use_weighted_masking = use_weighted_masking

        # Weighted masking needs unreduced losses so the per-frame weights can
        # be applied afterwards. Otherwise torch averages directly.
        if self.use_weighted_masking:
            reduction = "none"
        else:
            reduction = "mean"
        self.l1_criterion = torch.nn.L1Loss(reduction=reduction)
        self.mse_criterion = torch.nn.MSELoss(reduction=reduction)
        self.bce_criterion = torch.nn.BCEWithLogitsLoss(
            reduction=reduction, pos_weight=torch.tensor(bce_pos_weight)
        )

        # Checkpoints without ``bce_criterion.pos_weight`` get it filled in on load.
        self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)

    def forward(self, after_outs, before_outs, logits, ys, labels, olens):
        """Evaluate the losses for one batch.

        Args:
            after_outs (Tensor): Post-postnet outputs (B, Lmax, odim).
            before_outs (Tensor): Pre-postnet outputs (B, Lmax, odim).
            logits (Tensor): Stop-token logits (B, Lmax).
            ys (Tensor): Padded target features (B, Lmax, odim).
            labels (Tensor): Stop-token targets (B, Lmax).
            olens (LongTensor): Target lengths (B,).

        Returns:
            Tuple of (L1 loss, MSE loss, BCE loss) tensors.
        """
        if self.use_masking:
            feat_mask = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
            step_mask = feat_mask[:, :, 0]
            ys, after_outs, before_outs = [
                tensor.masked_select(feat_mask)
                for tensor in (ys, after_outs, before_outs)
            ]
            labels, logits = [
                tensor.masked_select(step_mask) for tensor in (labels, logits)
            ]

        def _both(criterion):
            # Same criterion on post- and pre-postnet outputs, added together.
            return criterion(after_outs, ys) + criterion(before_outs, ys)

        l1_loss = _both(self.l1_criterion)
        mse_loss = _both(self.mse_criterion)
        bce_loss = self.bce_criterion(logits, labels)

        if self.use_weighted_masking:
            feat_mask = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
            # Every valid frame gets 1 / (length of its utterance).
            frame_weights = feat_mask.float() / feat_mask.sum(dim=1, keepdim=True).float()
            n_batch = ys.size(0)
            odim = ys.size(2)
            out_weights = frame_weights.div(n_batch * odim)
            logit_weights = frame_weights.div(n_batch)
            l1_loss = l1_loss.mul(out_weights).masked_select(feat_mask).sum()
            mse_loss = mse_loss.mul(out_weights).masked_select(feat_mask).sum()
            step_mask = feat_mask.squeeze(-1)
            bce_loss = (
                bce_loss.mul(logit_weights.squeeze(-1)).masked_select(step_mask).sum()
            )

        return l1_loss, mse_loss, bce_loss

    def _load_state_dict_pre_hook(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Insert ``bce_criterion.pos_weight`` into ``state_dict`` when absent.

        Newer versions register ``pos_weight`` as state, but older checkpoints
        lack the key, which would cause a missing-key error on load. The
        module's current value is used in its place.
        """
        state_dict.setdefault(
            prefix + "bce_criterion.pos_weight", self.bce_criterion.pos_weight
        )
class GuidedMultiHeadAttentionLoss(GuidedAttentionLoss):
    """Guided attention loss function module for multi head attention.

    Args:
        sigma (float, optional): Standard deviation to control
            how close attention to a diagonal.
        alpha (float, optional): Scaling coefficient (lambda).
        reset_always (bool, optional): Whether to always reset masks.
    """

    def forward(self, att_ws, ilens, olens):
        """Calculate forward propagation.

        Args:
            att_ws (Tensor):
                Batch of multi head attention weights (B, H, T_max_out, T_max_in).
            ilens (LongTensor): Batch of input lengths (B,).
            olens (LongTensor): Batch of output lengths (B,).

        Returns:
            Tensor: Guided attention loss value.
        """
        # Masks are built lazily and cached on the instance. They are cleared
        # after each call only when ``reset_always`` is set.
        if self.guided_attn_masks is None:
            # unsqueeze(1) adds a head axis so the (B, T_out, T_in) mask
            # broadcasts over every head in att_ws.
            self.guided_attn_masks = (
                self._make_guided_attention_masks(ilens, olens)
                .to(att_ws.device)
                .unsqueeze(1)
            )
        if self.masks is None:
            self.masks = self._make_masks(ilens, olens).to(att_ws.device).unsqueeze(1)
        losses = self.guided_attn_masks * att_ws
        # Average only over non-padded (output, input) positions.
        loss = torch.mean(losses.masked_select(self.masks))
        if self.reset_always:
            self._reset_masks()
        return self.alpha * loss

    def _make_guided_attention_masks(self, ilens, olens):
        """Build a zero-padded (B, max_olen, max_ilen) batch of guided masks."""
        n_batches = len(ilens)
        max_ilen = int(max(ilens))
        max_olen = int(max(olens))
        guided_attn_masks = torch.zeros(
            (n_batches, max_olen, max_ilen), device=olens.device
        )
        for idx, (ilen, olen) in enumerate(zip(ilens, olens)):
            guided_attn_masks[idx, :olen, :ilen] = self._make_guided_attention_mask(
                ilen, olen, self.sigma
            )
        return guided_attn_masks

    @staticmethod
    def _make_guided_attention_mask(ilen, olen, sigma):
        """Return the (olen, ilen) penalty mask for one utterance.

        The mask is ``1 - exp(-(y/ilen - x/olen)^2 / (2 sigma^2))``: near 0 on
        the diagonal, growing with distance from it. Must be a staticmethod:
        it is called as ``self._make_guided_attention_mask(ilen, olen, sigma)``.
        """
        grid_x, grid_y = torch.meshgrid(
            torch.arange(olen, device=olen.device),
            torch.arange(ilen, device=olen.device),
        )
        grid_x, grid_y = grid_x.float(), grid_y.float()
        return 1.0 - torch.exp(
            -((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma**2))
        )

    @staticmethod
    def _make_masks(ilens, olens):
        """Return a boolean (B, T_out, T_in) mask of valid (output, input) pairs.

        Must be a staticmethod: it is called as ``self._make_masks(ilens, olens)``.
        """
        in_masks = make_non_pad_mask(ilens).to(ilens.device)  # (B, T_in)
        out_masks = make_non_pad_mask(olens).to(olens.device)  # (B, T_out)
        return out_masks.unsqueeze(-1) & in_masks.unsqueeze(-2)  # (B, T_out, T_in)