# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------

import math
import re
from dataclasses import dataclass

from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterionConfig
from fairseq.logging.meters import safe_round

from artst.criterions.speech_pretrain_criterion import SpeechPretrainCriterion, SpeechPretrainCriterionConfig
from artst.criterions.speech_to_text_loss import SpeechtoTextLoss, SpeechtoTextLossConfig
from artst.criterions.text_pretrain_criterion import TextPretrainCriterion, TextPretrainCriterionConfig
from artst.criterions.text_to_speech_loss import TexttoSpeechLoss
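

# ArTSTCriterionConfig simply merges the fields of the four sub-criterion
# configs (label-smoothed CE, text pre-training, speech pre-training and
# speech-to-text), so a single --criterion entry covers every ArTST task.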
@dataclass
class ArTSTCriterionConfig(
    LabelSmoothedCrossEntropyCriterionConfig,
    TextPretrainCriterionConfig,
    SpeechPretrainCriterionConfig,
    SpeechtoTextLossConfig,
):
    pass


@register_criterion("artst", dataclass=ArTSTCriterionConfig)
class ArTSTCriterion(FairseqCriterion):
    def __init__(
        self,
        task,
        sentence_avg,
        label_smoothing,
        pred_masked_weight,
        pred_nomask_weight,
        loss_weights=None,
        log_keys=None,
        ignore_prefix_size=0,
        report_accuracy=False,
        use_masking=True,
        use_weighted_masking=False,
        loss_type="L1",
        bce_pos_weight=5.0,
        bce_loss_lambda=1.0,
        use_guided_attn_loss=False,
        num_heads_applied_guided_attn=2,
        ce_weight=1.0,
        ctc_weight=0.0,
        hubert_weight=1.0,
        dec_weight=1.0,
        bart_weight=1.0,
    ):
        super().__init__(task)
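
        # TTS-style loss (used for t2s and s2s): L1/L2 spectrogram regression,
        # BCE stop-token loss and, optionally, guided attention loss.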
        self.speech_criterion = TexttoSpeechLoss(
            task,
            sentence_avg,
            use_masking,
            use_weighted_masking,
            loss_type,
            bce_pos_weight,
            bce_loss_lambda,
            use_guided_attn_loss,
            num_heads_applied_guided_attn=num_heads_applied_guided_attn,
        )
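        # ASR-style loss (used for s2t and s2c): label-smoothed cross-entropy,
        # optionally mixed with a CTC term via ce_weight / ctc_weight.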
        self.text_criterion = SpeechtoTextLoss(
            SpeechtoTextLossConfig,
            task,
            sentence_avg,
            label_smoothing,
            ignore_prefix_size,
            report_accuracy,
            ce_weight,
            ctc_weight,
        )
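        # BART-style denoising loss for text pre-training batches.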
        self.text_pretrain_criterion = TextPretrainCriterion(
            task,
            sentence_avg,
            bart_weight,
            loss_weights,
        )
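        # HuBERT-style masked-prediction loss (plus decoder reconstruction
        # terms) for speech pre-training batches.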
        self.speech_pretrain_criterion = SpeechPretrainCriterion(
            task,
            sentence_avg,
            pred_masked_weight,
            pred_nomask_weight,
            loss_weights,
            log_keys,
            use_masking,
            use_weighted_masking,
            loss_type,
            bce_pos_weight,
            hubert_weight,
            dec_weight,
        )

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        task_name = sample['task_name']
        if task_name == 's2t' or task_name == 's2c':
            return self.text_criterion(model, sample, reduce)
        elif task_name == 't2s' or task_name == 's2s':
            return self.speech_criterion(model, sample)
        elif task_name == 'text_pretrain':
            return self.text_pretrain_criterion(model, sample, reduce)
        elif task_name == 'speech_pretrain':
            return self.speech_pretrain_criterion(model, sample, reduce)

    @classmethod
    def reduce_metrics(cls, logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        logging_outputs_dict = {}
        for logging_output in logging_outputs:
            for task_name in logging_output:
                if task_name not in ['s2t', 't2s', 's2c', 's2s', 'text_pretrain', 'speech_pretrain']:
                    continue
                if task_name not in logging_outputs_dict:
                    logging_outputs_dict[task_name] = []
                logging_outputs_dict[task_name].append(logging_output[task_name])
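
        # Reduce each task's logs separately; every scalar is prefixed with the
        # task name so metrics from different tasks do not collide.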
        for task_name in logging_outputs_dict:
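            # s2t (ASR): cross-entropy / CTC losses, accuracy, and the error
            # counts used to derive UER/WER.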
            if task_name == 's2t':
                # LabelSmoothedCrossEntropyCriterion.reduce_metrics([logging_output['s2t'] for logging_output in logging_outputs])
                s2t_logging_output = logging_outputs_dict[task_name]
                # s2t_sum = sum(log.get("ce_loss", 0) for log in logging_outputs)
                loss_sum = sum(log.get("loss", 0) for log in s2t_logging_output)
                nll_loss_sum = sum(log.get("nll_loss", 0) for log in s2t_logging_output)
                ntokens = sum(log.get("ntokens", 0) for log in s2t_logging_output)
                ce_loss_sum = sum(log.get("ce_loss", 0) for log in s2t_logging_output)
                ctc_loss_sum = sum(log.get("ctc_loss", 0) for log in s2t_logging_output)
                sample_size = max(1, sum(log.get("sample_size", 0) for log in s2t_logging_output))
                metrics.log_scalar(
                    "s2t_loss", loss_sum / sample_size / math.log(2), sample_size, 1, round=3
                )
                metrics.log_scalar(
                    "s2t_nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, 2, round=3
                )
                metrics.log_derived(
                    "s2t_ppl", lambda meters: utils.get_perplexity(meters["s2t_nll_loss"].avg, 2)
                )
                metrics.log_scalar(
                    "ctc_loss", ctc_loss_sum / sample_size / math.log(2), ntokens, 2, round=3
                )
                metrics.log_scalar(
                    "ce_loss", ce_loss_sum / ntokens, ntokens, 2, round=3
                )
                total = utils.item(sum(log.get("total", 0) for log in s2t_logging_output))
                if total > 0:
                    metrics.log_scalar("s2t_total", total)
                    n_correct = utils.item(
                        sum(log.get("n_correct", 0) for log in s2t_logging_output)
                    )
                    metrics.log_scalar("s2t_n_correct", n_correct)
                    metrics.log_derived(
                        "s2t_accuracy",
                        lambda meters: round(
                            meters["s2t_n_correct"].sum * 100.0 / meters["s2t_total"].sum, 3
                        )
                        if meters["s2t_total"].sum > 0
                        else float("nan"),
                        2,
                    )
                c_errors = sum(log.get("c_errors", 0) for log in s2t_logging_output)
                metrics.log_scalar("_c_errors", c_errors)
                c_total = sum(log.get("c_total", 0) for log in s2t_logging_output)
                metrics.log_scalar("_c_total", c_total)
                w_errors = sum(log.get("w_errors", 0) for log in s2t_logging_output)
                metrics.log_scalar("_w_errors", w_errors)
                wv_errors = sum(log.get("wv_errors", 0) for log in s2t_logging_output)
                metrics.log_scalar("_wv_errors", wv_errors)
                w_total = sum(log.get("w_total", 0) for log in s2t_logging_output)
                metrics.log_scalar("_w_total", w_total)
                if c_total > 0:
                    metrics.log_derived(
                        "uer",
                        lambda meters: safe_round(
                            meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
                        )
                        if meters["_c_total"].sum > 0
                        else float("nan"),
                    )
                if w_total > 0:
                    metrics.log_derived(
                        "wer",
                        lambda meters: safe_round(
                            meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
                        )
                        if meters["_w_total"].sum > 0
                        else float("nan"),
                    )
                    metrics.log_derived(
                        "raw_wer",
                        lambda meters: safe_round(
                            meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
                        )
                        if meters["_w_total"].sum > 0
                        else float("nan"),
                    )
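
            # t2s (TTS): L1/L2 and stop-token BCE losses, learned alpha scalars,
            # and the optional guided attention loss.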
            if task_name == 't2s':
                # TTSLossCriterion.reduce_metrics([logging_output['t2s'] for logging_output in logging_outputs])
                # t2s_sum = sum(log.get("speech_loss", 0) for log in logging_outputs)
                t2s_logging_output = logging_outputs_dict[task_name]
                loss_sum = sum(log.get("loss", 0) for log in t2s_logging_output)
                l1_loss_sum = sum(log.get("l1_loss", 0) for log in t2s_logging_output)
                l2_loss_sum = sum(log.get("l2_loss", 0) for log in t2s_logging_output)
                bce_loss_sum = sum(log.get("bce_loss", 0) for log in t2s_logging_output)
                sample_size = max(1, sum(log.get("sample_size", 0) for log in t2s_logging_output))
                metrics.log_scalar(
                    "t2s_loss", loss_sum / sample_size, sample_size, 1, round=5
                )
                encoder_alpha_sum = sum(log.get("encoder_alpha", 0) for log in t2s_logging_output)
                decoder_alpha_sum = sum(log.get("decoder_alpha", 0) for log in t2s_logging_output)
                ngpu = sum(log.get("ngpu", 0) for log in t2s_logging_output)
                metrics.log_scalar(
                    "t2s_l1_loss", l1_loss_sum / sample_size, sample_size, 2, round=5
                )
                metrics.log_scalar(
                    "t2s_l2_loss", l2_loss_sum / sample_size, sample_size, 2, round=5
                )
                metrics.log_scalar(
                    "t2s_bce_loss", bce_loss_sum / sample_size, sample_size, 2, round=5
                )
                metrics.log_scalar(
                    "t2s_encoder_alpha", encoder_alpha_sum / sample_size, sample_size, round=5
                )
                metrics.log_scalar(
                    "t2s_decoder_alpha", decoder_alpha_sum / sample_size, sample_size, round=5
                )
                if "enc_dec_attn_loss" in t2s_logging_output[0]:
                    enc_dec_attn_loss_sum = sum(log.get("enc_dec_attn_loss", 0) for log in t2s_logging_output)
                    metrics.log_scalar(
                        "t2s_enc_dec_attn_loss", enc_dec_attn_loss_sum / sample_size, sample_size, round=8
                    )
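
            # s2c (speech classification): cross-entropy loss and accuracy.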
            if task_name == 's2c':
                s2c_logging_output = logging_outputs_dict[task_name]
                loss_sum = sum(log.get("loss", 0) for log in s2c_logging_output)
                nll_loss_sum = sum(log.get("nll_loss", 0) for log in s2c_logging_output)
                ntokens = sum(log.get("ntokens", 0) for log in s2c_logging_output)
                sample_size = max(1, sum(log.get("sample_size", 0) for log in s2c_logging_output))
                metrics.log_scalar(
                    "s2c_loss", loss_sum / sample_size / math.log(2), sample_size, 1, round=3
                )
                metrics.log_scalar(
                    "s2c_nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, 2, round=3
                )
                total = utils.item(sum(log.get("total", 0) for log in s2c_logging_output))
                if total > 0:
                    metrics.log_scalar("s2c_total", total)
                    n_correct = utils.item(sum(log.get("n_correct", 0) for log in s2c_logging_output))
                    metrics.log_scalar("s2c_n_correct", n_correct)
                    metrics.log_derived(
                        "s2c_accuracy",
                        lambda meters: round(
                            meters["s2c_n_correct"].sum * 100.0 / meters["s2c_total"].sum, 3
                        )
                        if meters["s2c_total"].sum > 0
                        else float("nan"),
                        2,
                    )
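
            # s2s (speech-to-speech): same TTS-style losses as t2s.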
            if task_name == 's2s':
                s2s_logging_output = logging_outputs_dict[task_name]
                loss_sum = sum(log.get("loss", 0) for log in s2s_logging_output)
                l1_loss_sum = sum(log.get("l1_loss", 0) for log in s2s_logging_output)
                l2_loss_sum = sum(log.get("l2_loss", 0) for log in s2s_logging_output)
                bce_loss_sum = sum(log.get("bce_loss", 0) for log in s2s_logging_output)
                sample_size = max(1, sum(log.get("sample_size", 0) for log in s2s_logging_output))
                metrics.log_scalar(
                    "s2s_loss", loss_sum / sample_size, sample_size, 1, round=5
                )
                encoder_alpha_sum = sum(log.get("encoder_alpha", 0) for log in s2s_logging_output)
                decoder_alpha_sum = sum(log.get("decoder_alpha", 0) for log in s2s_logging_output)
                ngpu = sum(log.get("ngpu", 0) for log in s2s_logging_output)
                metrics.log_scalar(
                    "s2s_l1_loss", l1_loss_sum / sample_size, sample_size, 2, round=5
                )
                metrics.log_scalar(
                    "s2s_l2_loss", l2_loss_sum / sample_size, sample_size, 2, round=5
                )
                metrics.log_scalar(
                    "s2s_bce_loss", bce_loss_sum / sample_size, sample_size, 2, round=5
                )
                metrics.log_scalar(
                    "s2s_decoder_alpha", decoder_alpha_sum / sample_size, sample_size, round=5
                )
                if "enc_dec_attn_loss" in s2s_logging_output[0]:
                    enc_dec_attn_loss_sum = sum(log.get("enc_dec_attn_loss", 0) for log in s2s_logging_output)
                    metrics.log_scalar(
                        "s2s_enc_dec_attn_loss", enc_dec_attn_loss_sum / sample_size, sample_size, round=8
                    )
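
            # text_pretrain: BART denoising loss and perplexity, plus quantizer
            # prob/code perplexities when they are logged.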
            if task_name == 'text_pretrain':
                bart_logging_output = logging_outputs_dict[task_name]
                loss_sum = sum(log.get("loss", 0) for log in bart_logging_output)
                ntokens = sum(log.get("ntokens", 0) for log in bart_logging_output)
                sample_size = max(1, sum(log.get("sample_size", 0) for log in bart_logging_output))
                bart_loss_sum = sum(log.get("bart_loss", 0) for log in bart_logging_output)
                # we divide by log(2) to convert the loss from base e to base 2
                metrics.log_scalar(
                    "text_loss", loss_sum / sample_size / math.log(2), sample_size, round=3
                )
                metrics.log_scalar(
                    "bart_loss", bart_loss_sum / sample_size / math.log(2), ntokens, 2, round=3
                )
                if sample_size != ntokens:
                    metrics.log_scalar(
                        "bart_nll_loss", bart_loss_sum / ntokens / math.log(2), ntokens, round=3
                    )
                    metrics.log_derived(
                        "bart_ppl", lambda meters: utils.get_perplexity(meters["bart_nll_loss"].avg)
                    )
                else:
                    metrics.log_derived(
                        "bart_ppl", lambda meters: utils.get_perplexity(meters["bart_loss"].avg)
                    )
                metrics.log_scalar("bart_wpb", ntokens, priority=180, round=1)

                val_prob_perplexity = 0
                val_code_perplexity = 0
                sample_size_pp = 0
                count_log_cp = 0
                for log in bart_logging_output:
                    if "loss_prob_perplexity" in log:
                        val_prob_perplexity = val_prob_perplexity + log["loss_prob_perplexity"]
                        sample_size_pp = sample_size_pp + log["sample_size"]
                    if "code_perplexity" in log:
                        val_code_perplexity = val_code_perplexity + log["code_perplexity"]
                        count_log_cp = count_log_cp + 1
                if val_prob_perplexity > 0:
                    metrics.log_scalar("text_loss_prob_perplexity", val_prob_perplexity / sample_size_pp / math.log(2), round=3)
                if val_code_perplexity > 0:
                    metrics.log_scalar("text_code_perplexity", val_code_perplexity / count_log_cp, round=3)
            if task_name == 'speech_pretrain':
                hubert_logging_output = logging_outputs_dict[task_name]
                loss_sum = sum(log.get("loss", 0) for log in hubert_logging_output)
                ntokens = sum(log.get("ntokens", 0) for log in hubert_logging_output)
                sample_size = max(1, sum(log.get("sample_size", 0) for log in hubert_logging_output))
                dec_loss_sum = sum(log.get("dec_loss", 0) for log in hubert_logging_output)
                l1_loss_sum = sum(log.get("l1_loss", 0) for log in hubert_logging_output)
                l2_loss_sum = sum(log.get("l2_loss", 0) for log in hubert_logging_output)
                bce_loss_sum = sum(log.get("bce_loss", 0) for log in hubert_logging_output)
                ngpu = sum(log.get("ngpu", 0) for log in hubert_logging_output)

                metrics.log_scalar("hubert_loss", loss_sum / sample_size / math.log(2), sample_size, round=3)
                if sample_size != ntokens:
                    metrics.log_scalar("hubert_nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3)
                    metrics.log_derived("hubert_ppl", lambda meters: utils.get_perplexity(meters["hubert_nll_loss"].avg))
                else:
                    metrics.log_derived("hubert_ppl", lambda meters: utils.get_perplexity(meters["hubert_loss"].avg))

                counts = {}
                for lk in hubert_logging_output[0].keys():
                    if lk.startswith("count_"):
                        val = sum(log[lk] for log in hubert_logging_output)
                        metrics.log_scalar("hubert_" + lk, val)
                        counts[lk] = val

                for lk in hubert_logging_output[0].keys():
                    if lk.startswith("loss_") and lk != 'loss_prob_perplexity':
                        val = sum(log[lk] for log in hubert_logging_output)
                        metrics.log_scalar("hubert_" + lk, val / sample_size / math.log(2), round=3)
                    elif lk.startswith("correct_"):
                        val = sum(log[lk] for log in hubert_logging_output)
                        metrics.log_scalar("hubert_" + lk, val / counts[re.sub("correct", "count", lk)])
                    # elif lk == 'code_perplexity':
                    #     val = sum(log[lk] for log in hubert_logging_output)
                    #     metrics.log_scalar("hubert_" + lk, val / len(hubert_logging_output), round=3)

                val_prob_perplexity = 0
                val_code_perplexity = 0
                sample_size_pp = 0
                count_log_cp = 0
                for log in hubert_logging_output:
                    if "loss_prob_perplexity" in log:
                        val_prob_perplexity = val_prob_perplexity + log["loss_prob_perplexity"]
                        sample_size_pp = sample_size_pp + log["sample_size"]
                    if "code_perplexity" in log:
                        val_code_perplexity = val_code_perplexity + log["code_perplexity"]
                        count_log_cp = count_log_cp + 1
                if val_prob_perplexity > 0:
                    metrics.log_scalar("hubert_loss_prob_perplexity", val_prob_perplexity / sample_size_pp / math.log(2), round=3)
                if val_code_perplexity > 0:
                    metrics.log_scalar("hubert_code_perplexity", val_code_perplexity / count_log_cp, round=3)

                metrics.log_scalar(
                    "hubert_dec_loss", dec_loss_sum / ngpu, sample_size, 2, round=5
                )
                metrics.log_scalar(
                    "hubert_l1_loss", l1_loss_sum / ngpu, sample_size, 2, round=5
                )
                metrics.log_scalar(
                    "hubert_l2_loss", l2_loss_sum / ngpu, sample_size, 2, round=5
                )
                metrics.log_scalar(
                    "hubert_bce_loss", bce_loss_sum / ngpu, sample_size, 2, round=5
                )
                if "enc_dec_attn_loss" in hubert_logging_output[0]:
                    enc_dec_attn_loss_sum = sum(log.get("enc_dec_attn_loss", 0) for log in hubert_logging_output)
                    metrics.log_scalar(
                        "hubert_enc_dec_attn_loss", enc_dec_attn_loss_sum / ngpu, sample_size, round=8
                    )
                metrics.log_scalar("hubert_wpb", ntokens, priority=180, round=1)
        loss = sum(log.get("loss", 0) for log in logging_outputs)
        sample_size = max(1, sum(log.get("sample_size", 0) for log in logging_outputs))
        metrics.log_scalar(
            "loss", loss / sample_size, sample_size, 1, round=5
        )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return False