# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transform (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------

import math
import re
from dataclasses import dataclass

from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import (
    LabelSmoothedCrossEntropyCriterionConfig,
)
from fairseq.logging.meters import safe_round

from artst.criterions.speech_pretrain_criterion import (
    SpeechPretrainCriterion,
    SpeechPretrainCriterionConfig,
)
from artst.criterions.speech_to_text_loss import (
    SpeechtoTextLoss,
    SpeechtoTextLossConfig,
)
from artst.criterions.text_pretrain_criterion import (
    TextPretrainCriterion,
    TextPretrainCriterionConfig,
)
from artst.criterions.text_to_speech_loss import TexttoSpeechLoss


@dataclass
class ArTSTCriterionConfig(
    LabelSmoothedCrossEntropyCriterionConfig,
    TextPretrainCriterionConfig,
    SpeechPretrainCriterionConfig,
    SpeechtoTextLossConfig,
):
    """Configuration of the "artst" criterion.

    Declares no fields of its own; every option is inherited from the four
    parent configuration dataclasses.
    """


@register_criterion("artst", dataclass=ArTSTCriterionConfig)
class ArTSTCriterion(FairseqCriterion):
    """Multi-task criterion that routes each batch to a per-task sub-criterion.

    ``forward`` selects the sub-criterion from ``sample['task_name']``;
    ``reduce_metrics`` logs a separate family of metrics for each task name
    found in the logging outputs.
    """

    # Task names whose logging outputs `reduce_metrics` knows how to aggregate.
    _TASK_NAMES = ("s2t", "t2s", "s2c", "s2s", "text_pretrain", "speech_pretrain")

    def __init__(
        self,
        task,
        sentence_avg,
        label_smoothing,
        pred_masked_weight,
        pred_nomask_weight,
        loss_weights=None,
        log_keys=None,
        ignore_prefix_size=0,
        report_accuracy=False,
        use_masking=True,
        use_weighted_masking=False,
        loss_type="L1",
        bce_pos_weight=5.0,
        bce_loss_lambda=1.0,
        use_guided_attn_loss=False,
        num_heads_applied_guided_attn=2,
        ce_weight=1.0,
        ctc_weight=0.0,
        hubert_weight=1.0,
        dec_weight=1.0,
        bart_weight=1.0,
    ):
        """Build the four sub-criteria.

        The arguments are not interpreted here; they are forwarded unchanged
        to the constructors of ``TexttoSpeechLoss``, ``SpeechtoTextLoss``,
        ``TextPretrainCriterion`` and ``SpeechPretrainCriterion``.
        """
        super().__init__(task)
        # Used by `forward` for the 't2s' and 's2s' tasks.
        self.speech_criterion = TexttoSpeechLoss(
            task,
            sentence_avg,
            use_masking,
            use_weighted_masking,
            loss_type,
            bce_pos_weight,
            bce_loss_lambda,
            use_guided_attn_loss,
            num_heads_applied_guided_attn=num_heads_applied_guided_attn,
        )
        # Used by `forward` for the 's2t' and 's2c' tasks.
        # NOTE(review): the config *class* (not an instance) is passed as the
        # first argument -- kept as in the original; confirm against
        # SpeechtoTextLoss.__init__.
        self.text_criterion = SpeechtoTextLoss(
            SpeechtoTextLossConfig,
            task,
            sentence_avg,
            label_smoothing,
            ignore_prefix_size,
            report_accuracy,
            ce_weight,
            ctc_weight,
        )
        # Used by `forward` for the 'text_pretrain' task.
        self.text_pretrain_criterion = TextPretrainCriterion(
            task,
            sentence_avg,
            bart_weight,
            loss_weights,
        )
        # Used by `forward` for the 'speech_pretrain' task.
        self.speech_pretrain_criterion = SpeechPretrainCriterion(
            task,
            sentence_avg,
            pred_masked_weight,
            pred_nomask_weight,
            loss_weights,
            log_keys,
            use_masking,
            use_weighted_masking,
            loss_type,
            bce_pos_weight,
            hubert_weight,
            dec_weight,
        )

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        The sub-criterion is chosen from ``sample['task_name']``.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training

        Raises:
            ValueError: if ``sample['task_name']`` is not one of the
                supported task names.
        """
        task_name = sample["task_name"]
        if task_name in ("s2t", "s2c"):
            return self.text_criterion(model, sample, reduce)
        if task_name in ("t2s", "s2s"):
            # The speech criterion is called without the `reduce` flag.
            return self.speech_criterion(model, sample)
        if task_name == "text_pretrain":
            return self.text_pretrain_criterion(model, sample, reduce)
        if task_name == "speech_pretrain":
            return self.speech_pretrain_criterion(model, sample, reduce)
        # Previously this fell through and returned None, which only failed
        # later when the caller tried to unpack the result.
        raise ValueError(
            f"Unsupported task_name {task_name!r}; "
            f"expected one of {list(self._TASK_NAMES)}"
        )

    @classmethod
    def reduce_metrics(cls, logging_outputs):
        """Aggregate logging outputs from data parallel training.

        Each element of ``logging_outputs`` is treated as a mapping; entries
        keyed by a known task name are grouped per task and handed to the
        matching ``_reduce_*`` helper. Finally an overall ``loss`` is logged
        from the top-level ``loss`` / ``sample_size`` entries.
        """
        per_task = {}
        for logging_output in logging_outputs:
            for task_name in logging_output:
                if task_name in cls._TASK_NAMES:
                    per_task.setdefault(task_name, []).append(
                        logging_output[task_name]
                    )

        # Tasks are processed in the order they were first encountered.
        for task_name, task_logs in per_task.items():
            if task_name == "s2t":
                cls._reduce_s2t(task_logs)
            elif task_name == "t2s":
                cls._reduce_speech_generation("t2s", task_logs, True)
            elif task_name == "s2c":
                cls._reduce_s2c(task_logs)
            elif task_name == "s2s":
                # The original code never logged an encoder alpha for s2s.
                cls._reduce_speech_generation("s2s", task_logs, False)
            elif task_name == "text_pretrain":
                cls._reduce_text_pretrain(task_logs)
            elif task_name == "speech_pretrain":
                cls._reduce_speech_pretrain(task_logs)

        loss = sum(log.get("loss", 0) for log in logging_outputs)
        sample_size = max(
            1, sum(log.get("sample_size", 0) for log in logging_outputs)
        )
        metrics.log_scalar("loss", loss / sample_size, sample_size, 1, round=5)

    @classmethod
    def _log_accuracy(cls, prefix, logs):
        """Log `<prefix>_total`, `<prefix>_n_correct` and `<prefix>_accuracy`.

        Nothing is logged unless the summed ``total`` is positive.
        """
        total = utils.item(sum(log.get("total", 0) for log in logs))
        if total > 0:
            total_key = prefix + "_total"
            correct_key = prefix + "_n_correct"
            metrics.log_scalar(total_key, total)
            n_correct = utils.item(sum(log.get("n_correct", 0) for log in logs))
            metrics.log_scalar(correct_key, n_correct)
            metrics.log_derived(
                prefix + "_accuracy",
                lambda meters: round(
                    meters[correct_key].sum * 100.0 / meters[total_key].sum, 3
                )
                if meters[total_key].sum > 0
                else float("nan"),
                2,
            )

    @classmethod
    def _log_perplexities(cls, prefix, logs):
        """Log `<prefix>_loss_prob_perplexity` and `<prefix>_code_perplexity`.

        Only logs that contain the respective key contribute, and each metric
        is logged only when its accumulated value is positive.
        """
        prob_perplexity_sum = 0
        code_perplexity_sum = 0
        prob_sample_size = 0
        code_log_count = 0
        for log in logs:
            if "loss_prob_perplexity" in log:
                prob_perplexity_sum = prob_perplexity_sum + log["loss_prob_perplexity"]
                prob_sample_size = prob_sample_size + log["sample_size"]
            if "code_perplexity" in log:
                code_perplexity_sum = code_perplexity_sum + log["code_perplexity"]
                code_log_count = code_log_count + 1
        if prob_perplexity_sum > 0:
            metrics.log_scalar(
                prefix + "_loss_prob_perplexity",
                prob_perplexity_sum / prob_sample_size / math.log(2),
                round=3,
            )
        if code_perplexity_sum > 0:
            metrics.log_scalar(
                prefix + "_code_perplexity",
                code_perplexity_sum / code_log_count,
                round=3,
            )

    @classmethod
    def _reduce_s2t(cls, logs):
        """Log the 's2t' metrics: losses, perplexity, accuracy and error rates."""
        loss_sum = sum(log.get("loss", 0) for log in logs)
        nll_loss_sum = sum(log.get("nll_loss", 0) for log in logs)
        ntokens = sum(log.get("ntokens", 0) for log in logs)
        ce_loss_sum = sum(log.get("ce_loss", 0) for log in logs)
        ctc_loss_sum = sum(log.get("ctc_loss", 0) for log in logs)
        sample_size = max(1, sum(log.get("sample_size", 0) for log in logs))
        # Same guard as for sample_size: avoid ZeroDivisionError when no
        # tokens were reported. The raw `ntokens` is still used as weight.
        token_denom = max(1, ntokens)

        metrics.log_scalar(
            "s2t_loss", loss_sum / sample_size / math.log(2), sample_size, 1, round=3
        )
        metrics.log_scalar(
            "s2t_nll_loss", nll_loss_sum / token_denom / math.log(2), ntokens, 2, round=3
        )
        metrics.log_derived(
            "s2t_ppl",
            lambda meters: utils.get_perplexity(meters["s2t_nll_loss"].avg, 2),
        )
        metrics.log_scalar(
            "ctc_loss", ctc_loss_sum / sample_size / math.log(2), ntokens, 2, round=3
        )
        metrics.log_scalar("ce_loss", ce_loss_sum / token_denom, ntokens, 2, round=3)

        cls._log_accuracy("s2t", logs)

        # Raw error counters; the derived rates below read them back from the
        # meters so that they are computed over the accumulated sums.
        c_errors = sum(log.get("c_errors", 0) for log in logs)
        metrics.log_scalar("_c_errors", c_errors)
        c_total = sum(log.get("c_total", 0) for log in logs)
        metrics.log_scalar("_c_total", c_total)
        w_errors = sum(log.get("w_errors", 0) for log in logs)
        metrics.log_scalar("_w_errors", w_errors)
        wv_errors = sum(log.get("wv_errors", 0) for log in logs)
        metrics.log_scalar("_wv_errors", wv_errors)
        w_total = sum(log.get("w_total", 0) for log in logs)
        metrics.log_scalar("_w_total", w_total)

        if c_total > 0:
            metrics.log_derived(
                "uer",
                lambda meters: safe_round(
                    meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
                )
                if meters["_c_total"].sum > 0
                else float("nan"),
            )
        if w_total > 0:
            metrics.log_derived(
                "wer",
                lambda meters: safe_round(
                    meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
                )
                if meters["_w_total"].sum > 0
                else float("nan"),
            )
            metrics.log_derived(
                "raw_wer",
                lambda meters: safe_round(
                    meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
                )
                if meters["_w_total"].sum > 0
                else float("nan"),
            )

    @classmethod
    def _reduce_speech_generation(cls, prefix, logs, log_encoder_alpha):
        """Log the metrics shared by the 't2s' and 's2s' tasks.

        Args:
            prefix: metric name prefix ('t2s' or 's2s').
            logs: list of per-worker logging dicts for that task (non-empty).
            log_encoder_alpha: whether to also log `<prefix>_encoder_alpha`.
        """
        loss_sum = sum(log.get("loss", 0) for log in logs)
        l1_loss_sum = sum(log.get("l1_loss", 0) for log in logs)
        l2_loss_sum = sum(log.get("l2_loss", 0) for log in logs)
        bce_loss_sum = sum(log.get("bce_loss", 0) for log in logs)
        sample_size = max(1, sum(log.get("sample_size", 0) for log in logs))

        metrics.log_scalar(
            prefix + "_loss", loss_sum / sample_size, sample_size, 1, round=5
        )
        metrics.log_scalar(
            prefix + "_l1_loss", l1_loss_sum / sample_size, sample_size, 2, round=5
        )
        metrics.log_scalar(
            prefix + "_l2_loss", l2_loss_sum / sample_size, sample_size, 2, round=5
        )
        metrics.log_scalar(
            prefix + "_bce_loss", bce_loss_sum / sample_size, sample_size, 2, round=5
        )
        if log_encoder_alpha:
            encoder_alpha_sum = sum(log.get("encoder_alpha", 0) for log in logs)
            metrics.log_scalar(
                prefix + "_encoder_alpha",
                encoder_alpha_sum / sample_size,
                sample_size,
                round=5,
            )
        decoder_alpha_sum = sum(log.get("decoder_alpha", 0) for log in logs)
        metrics.log_scalar(
            prefix + "_decoder_alpha",
            decoder_alpha_sum / sample_size,
            sample_size,
            round=5,
        )

        # Presence of the key is decided from the first log only.
        if "enc_dec_attn_loss" in logs[0]:
            enc_dec_attn_loss_sum = sum(
                log.get("enc_dec_attn_loss", 0) for log in logs
            )
            metrics.log_scalar(
                prefix + "_enc_dec_attn_loss",
                enc_dec_attn_loss_sum / sample_size,
                sample_size,
                round=8,
            )

    @classmethod
    def _reduce_s2c(cls, logs):
        """Log the 's2c' metrics: loss, nll loss and accuracy."""
        loss_sum = sum(log.get("loss", 0) for log in logs)
        nll_loss_sum = sum(log.get("nll_loss", 0) for log in logs)
        ntokens = sum(log.get("ntokens", 0) for log in logs)
        sample_size = max(1, sum(log.get("sample_size", 0) for log in logs))
        # Guard the denominator only; the raw `ntokens` stays the weight.
        token_denom = max(1, ntokens)

        metrics.log_scalar(
            "s2c_loss", loss_sum / sample_size / math.log(2), sample_size, 1, round=3
        )
        metrics.log_scalar(
            "s2c_nll_loss", nll_loss_sum / token_denom / math.log(2), ntokens, 2, round=3
        )

        cls._log_accuracy("s2c", logs)

    @classmethod
    def _reduce_text_pretrain(cls, logs):
        """Log the 'text_pretrain' metrics (`text_*` and `bart_*` keys)."""
        loss_sum = sum(log.get("loss", 0) for log in logs)
        ntokens = sum(log.get("ntokens", 0) for log in logs)
        sample_size = max(1, sum(log.get("sample_size", 0) for log in logs))
        bart_loss_sum = sum(log.get("bart_loss", 0) for log in logs)

        # we divide by log(2) to convert the loss from base e to base 2
        metrics.log_scalar(
            "text_loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        metrics.log_scalar(
            "bart_loss", bart_loss_sum / sample_size / math.log(2), ntokens, 2, round=3
        )
        if sample_size != ntokens:
            # max(1, ...) guards the ntokens == 0 case, which lands here
            # because sample_size is clamped to at least 1.
            metrics.log_scalar(
                "bart_nll_loss",
                bart_loss_sum / max(1, ntokens) / math.log(2),
                ntokens,
                round=3,
            )
            metrics.log_derived(
                "bart_ppl",
                lambda meters: utils.get_perplexity(meters["bart_nll_loss"].avg),
            )
        else:
            metrics.log_derived(
                "bart_ppl",
                lambda meters: utils.get_perplexity(meters["bart_loss"].avg),
            )
        metrics.log_scalar("bart_wpb", ntokens, priority=180, round=1)

        cls._log_perplexities("text", logs)

    @classmethod
    def _reduce_speech_pretrain(cls, logs):
        """Log the 'speech_pretrain' metrics (all keys prefixed `hubert_`)."""
        loss_sum = sum(log.get("loss", 0) for log in logs)
        ntokens = sum(log.get("ntokens", 0) for log in logs)
        sample_size = max(1, sum(log.get("sample_size", 0) for log in logs))
        dec_loss_sum = sum(log.get("dec_loss", 0) for log in logs)
        l1_loss_sum = sum(log.get("l1_loss", 0) for log in logs)
        l2_loss_sum = sum(log.get("l2_loss", 0) for log in logs)
        bce_loss_sum = sum(log.get("bce_loss", 0) for log in logs)
        # Clamp like sample_size so logs without an "ngpu" entry do not
        # trigger a ZeroDivisionError below.
        ngpu = max(1, sum(log.get("ngpu", 0) for log in logs))

        metrics.log_scalar(
            "hubert_loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        if sample_size != ntokens:
            metrics.log_scalar(
                "hubert_nll_loss",
                loss_sum / max(1, ntokens) / math.log(2),
                ntokens,
                round=3,
            )
            metrics.log_derived(
                "hubert_ppl",
                lambda meters: utils.get_perplexity(meters["hubert_nll_loss"].avg),
            )
        else:
            metrics.log_derived(
                "hubert_ppl",
                lambda meters: utils.get_perplexity(meters["hubert_loss"].avg),
            )

        # First pass: collect every count_* total, because the correct_*
        # ratios in the second pass are normalised by the matching count.
        counts = {}
        for lk in logs[0]:
            if lk.startswith("count_"):
                val = sum(log[lk] for log in logs)
                metrics.log_scalar("hubert_" + lk, val)
                counts[lk] = val

        for lk in logs[0]:
            if lk.startswith("loss_") and lk != "loss_prob_perplexity":
                val = sum(log[lk] for log in logs)
                metrics.log_scalar(
                    "hubert_" + lk, val / sample_size / math.log(2), round=3
                )
            elif lk.startswith("correct_"):
                val = sum(log[lk] for log in logs)
                metrics.log_scalar(
                    "hubert_" + lk, val / counts[re.sub("correct", "count", lk)]
                )

        cls._log_perplexities("hubert", logs)

        # These four losses are averaged over `ngpu`, not over sample_size.
        metrics.log_scalar(
            "hubert_dec_loss", dec_loss_sum / ngpu, sample_size, 2, round=5
        )
        metrics.log_scalar(
            "hubert_l1_loss", l1_loss_sum / ngpu, sample_size, 2, round=5
        )
        metrics.log_scalar(
            "hubert_l2_loss", l2_loss_sum / ngpu, sample_size, 2, round=5
        )
        metrics.log_scalar(
            "hubert_bce_loss", bce_loss_sum / ngpu, sample_size, 2, round=5
        )
        if "enc_dec_attn_loss" in logs[0]:
            enc_dec_attn_loss_sum = sum(
                log.get("enc_dec_attn_loss", 0) for log in logs
            )
            metrics.log_scalar(
                "hubert_enc_dec_attn_loss",
                enc_dec_attn_loss_sum / ngpu,
                sample_size,
                round=8,
            )
        metrics.log_scalar("hubert_wpb", ntokens, priority=180, round=1)

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return False