diff --git a/README.md b/README.md index 58bf79964e44a8a8b93dd65a90ae787d2c13a077..2db49f26cd53d5b013ad7ee11aa59acc0bae995b 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ --- -title: Artst +title: ArtstTTS emoji: 🔥 colorFrom: yellow colorTo: gray @@ -7,6 +7,7 @@ sdk: gradio sdk_version: 4.7.1 app_file: app.py pinned: false +python_version: 3.8.2 --- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..5aa3656b446c3b673ae271aedf2cfb50d6004bbd --- /dev/null +++ b/app.py @@ -0,0 +1,59 @@ +import os +import torch +import gradio as gr +import os.path as op +import pyarabic.araby as araby + +from artst.tasks.artst import ArTSTTask +from transformers import SpeechT5HifiGan +from artst.models.artst import ArTSTTransformerModel +from fairseq.tasks.hubert_pretraining import LabelEncoder +from fairseq.data.audio.speech_to_text_dataset import get_features_or_waveform + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +WORK_DIR = os.getcwd() +checkpoint = torch.load('ckpts/clartts_tts.pt') +checkpoint['cfg']['task'].t5_task = 't2s' +task = ArTSTTask.setup_task(checkpoint['cfg']['task']) + +emb_path='embs/clartts.npy' +model = ArTSTTransformerModel.build_model(checkpoint['cfg']['model'], task) +model.load_state_dict(checkpoint['model']) + +checkpoint['cfg']['task'].bpe_tokenizer = task.build_bpe(checkpoint['cfg']['model']) +tokenizer = checkpoint['cfg']['task'].bpe_tokenizer + +processor = LabelEncoder(task.dicts['text']) + +vocoder = SpeechT5HifiGan.from_pretrained('microsoft/speecht5_hifigan').to(device) + +def get_embs(emb_path): + spkembs = get_features_or_waveform(emb_path) + spkembs = torch.from_numpy(spkembs).float().unsqueeze(0) + return spkembs + +def process_text(text): + text = araby.strip_diacritics(text) + return processor(tokenizer.encode(text)).reshape(1, -1) + +net_input = {} + +def inference(text, spkr=emb_path): + net_input['src_tokens'] = process_text(text) + net_input['spkembs'] = get_embs(spkr) + outs, _, attn = task.generate_speech( + [model], + net_input, + ) + with torch.no_grad(): + gen_audio = vocoder(outs.to(device)) + return (16000,gen_audio.cpu().numpy()) + +text_box = gr.Textbox(max_lines=2, label="Arabic Text") +out = gr.Audio(label="Synthesized Audio", type="numpy") +demo = gr.Interface(inference, \ + inputs=text_box, outputs=out, title="ArTST") + +if __name__ == "__main__": + demo.launch(share=True) diff --git a/artst/__init__.py b/artst/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8994f9a368ae4b2eff720fffb134e2a5b813ee1c --- /dev/null +++ b/artst/__init__.py @@ -0,0 +1 @@ +from . 
import data, tasks, criterions, models # noqa \ No newline at end of file diff --git a/artst/__pycache__/__init__.cpython-38.pyc b/artst/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c0edcb062c45775b90dfe65836e21e1549c5f0c Binary files /dev/null and b/artst/__pycache__/__init__.cpython-38.pyc differ diff --git a/artst/__pycache__/sequence_generator.cpython-38.pyc b/artst/__pycache__/sequence_generator.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fbc3f425f62c9bc722ae01e4beaace5d9d23c75 Binary files /dev/null and b/artst/__pycache__/sequence_generator.cpython-38.pyc differ diff --git a/artst/criterions/__init__.py b/artst/criterions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..228236e19780255b95d21f0accdeded5e4355cdf --- /dev/null +++ b/artst/criterions/__init__.py @@ -0,0 +1,10 @@ +import importlib +import os + + +for file in os.listdir(os.path.dirname(__file__)): + if file.endswith(".py") and not file.startswith("_"): + criterion_name = file[: file.find(".py")] + importlib.import_module( + "artst.criterions." + criterion_name + ) \ No newline at end of file diff --git a/artst/criterions/__pycache__/__init__.cpython-38.pyc b/artst/criterions/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..949dd993bd4900b60640edab4a5fce983d16f6bf Binary files /dev/null and b/artst/criterions/__pycache__/__init__.cpython-38.pyc differ diff --git a/artst/criterions/__pycache__/artst_criterion.cpython-38.pyc b/artst/criterions/__pycache__/artst_criterion.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9cddbbfbc00dc2fc89da0d32a482c9c42d740c83 Binary files /dev/null and b/artst/criterions/__pycache__/artst_criterion.cpython-38.pyc differ diff --git a/artst/criterions/__pycache__/speech_pretrain_criterion.cpython-38.pyc b/artst/criterions/__pycache__/speech_pretrain_criterion.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7a9bfa1d8d53b136a40bc0f77d92068eb82f85a Binary files /dev/null and b/artst/criterions/__pycache__/speech_pretrain_criterion.cpython-38.pyc differ diff --git a/artst/criterions/__pycache__/speech_to_text_loss.cpython-38.pyc b/artst/criterions/__pycache__/speech_to_text_loss.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04659ac81603802d2a039927a1e4af999c770984 Binary files /dev/null and b/artst/criterions/__pycache__/speech_to_text_loss.cpython-38.pyc differ diff --git a/artst/criterions/__pycache__/speecht5_criterion.cpython-38.pyc b/artst/criterions/__pycache__/speecht5_criterion.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94377dfec240d6d4d59b6505d1e2c960d9f188a1 Binary files /dev/null and b/artst/criterions/__pycache__/speecht5_criterion.cpython-38.pyc differ diff --git a/artst/criterions/__pycache__/text_pretrain_criterion.cpython-38.pyc b/artst/criterions/__pycache__/text_pretrain_criterion.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6bce15660c2d0d6055f99aef06a25d791762e9df Binary files /dev/null and b/artst/criterions/__pycache__/text_pretrain_criterion.cpython-38.pyc differ diff --git a/artst/criterions/__pycache__/text_to_speech_loss.cpython-38.pyc b/artst/criterions/__pycache__/text_to_speech_loss.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..43c4554d77a631f3ec3b7ff3778f1a7e4c1c8195 Binary files /dev/null and b/artst/criterions/__pycache__/text_to_speech_loss.cpython-38.pyc differ diff --git a/artst/criterions/artst_criterion.py b/artst/criterions/artst_criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..5260fbaa06854c8197fe06c345a86e7b86c27e75 --- /dev/null +++ b/artst/criterions/artst_criterion.py @@ -0,0 +1,443 @@ +# -------------------------------------------------------- +# ArTST: Arabic Text and Speech Transform (https://arxiv.org/abs/2310.16621) +# Github source: https://github.com/mbzuai-nlp/ArTST +# Based on speecht5, fairseq and espnet code bases +# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet +# -------------------------------------------------------- + +import re +from dataclasses import dataclass + +import math +from fairseq import metrics, utils +from fairseq.criterions import FairseqCriterion, register_criterion +from artst.criterions.text_to_speech_loss import TexttoSpeechLoss +from artst.criterions.text_pretrain_criterion import TextPretrainCriterion, TextPretrainCriterionConfig +from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterionConfig +from artst.criterions.speech_pretrain_criterion import SpeechPretrainCriterion, SpeechPretrainCriterionConfig +from artst.criterions.speech_to_text_loss import SpeechtoTextLoss, SpeechtoTextLossConfig +from fairseq.logging.meters import safe_round + +@dataclass +class ArTSTCriterionConfig( + LabelSmoothedCrossEntropyCriterionConfig, + TextPretrainCriterionConfig, + SpeechPretrainCriterionConfig, + SpeechtoTextLossConfig + ): + pass + +@register_criterion( + "artst", dataclass=ArTSTCriterionConfig +) +class ArTSTCriterion(FairseqCriterion): + def __init__( + self, + task, + sentence_avg, + label_smoothing, + pred_masked_weight, + pred_nomask_weight, + loss_weights=None, + log_keys=None, + ignore_prefix_size=0, + report_accuracy=False, + use_masking=True, + use_weighted_masking=False, + loss_type="L1", + bce_pos_weight=5.0, + bce_loss_lambda=1.0, + use_guided_attn_loss=False, + num_heads_applied_guided_attn=2, + ce_weight=1.0, + ctc_weight=0.0, + hubert_weight=1.0, + dec_weight=1.0, + bart_weight=1.0, + ): + super().__init__(task) + self.speech_criterion = TexttoSpeechLoss( + task, + sentence_avg, + use_masking, + use_weighted_masking, + loss_type, + bce_pos_weight, + bce_loss_lambda, + use_guided_attn_loss, + num_heads_applied_guided_attn=num_heads_applied_guided_attn, + ) + self.text_criterion = SpeechtoTextLoss( + SpeechtoTextLossConfig, + task, + sentence_avg, + label_smoothing, + ignore_prefix_size, + report_accuracy, + ce_weight, + ctc_weight + ) + self.text_pretrain_criterion = TextPretrainCriterion( + task, + sentence_avg, + bart_weight, + loss_weights, + ) + self.speech_pretrain_criterion = SpeechPretrainCriterion( + task, + sentence_avg, + pred_masked_weight, + pred_nomask_weight, + loss_weights, + log_keys, + use_masking, + use_weighted_masking, + loss_type, + bce_pos_weight, + hubert_weight, + dec_weight + ) + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. 
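+ 
+ The sample's task_name selects the sub-criterion: 's2t' and 's2c' are
+ scored with SpeechtoTextLoss, 't2s' and 's2s' with TexttoSpeechLoss, and
+ 'text_pretrain' / 'speech_pretrain' with the corresponding pretraining
+ criteria.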
+ + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + + task_name = sample['task_name'] + if task_name == 's2t' or task_name == 's2c': + return self.text_criterion(model, sample, reduce) + elif task_name == 't2s' or task_name == 's2s': + return self.speech_criterion(model, sample) + elif task_name == 'text_pretrain': + return self.text_pretrain_criterion(model, sample, reduce) + elif task_name == 'speech_pretrain': + return self.speech_pretrain_criterion(model, sample, reduce) + + @classmethod + def reduce_metrics(cls, logging_outputs): + """Aggregate logging outputs from data parallel training.""" + logging_outputs_dict = {} + for logging_output in logging_outputs: + for task_name in logging_output: + if task_name not in ['s2t', 't2s', 's2c', 's2s', 'text_pretrain', 'speech_pretrain']: + continue + + if task_name not in logging_outputs_dict: + logging_outputs_dict[task_name] = [] + logging_outputs_dict[task_name].append(logging_output[task_name]) + + for task_name in logging_outputs_dict: + if task_name == 's2t': + # LabelSmoothedCrossEntropyCriterion.reduce_metrics([logging_output['s2t'] for logging_output in logging_outputs]) + s2t_logging_output = logging_outputs_dict[task_name] + # s2t_sum = sum(log.get("ce_loss", 0) for log in logging_outputs) + loss_sum = sum(log.get("loss", 0) for log in s2t_logging_output) + nll_loss_sum = sum(log.get("nll_loss", 0) for log in s2t_logging_output) + ntokens = sum(log.get("ntokens", 0) for log in s2t_logging_output) + ce_loss_sum = sum(log.get("ce_loss", 0) for log in s2t_logging_output) + ctc_loss_sum = sum(log.get("ctc_loss", 0) for log in s2t_logging_output) + + sample_size = max(1, sum(log.get("sample_size", 0) for log in s2t_logging_output)) + metrics.log_scalar( + "s2t_loss", loss_sum / sample_size / math.log(2), sample_size, 1, round=3 + ) + + metrics.log_scalar( + "s2t_nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, 2, round=3 + ) + metrics.log_derived( + "s2t_ppl", lambda meters: utils.get_perplexity(meters["s2t_nll_loss"].avg, 2) + ) + metrics.log_scalar( + "ctc_loss", ctc_loss_sum / sample_size / math.log(2), ntokens, 2, round=3 + ) + metrics.log_scalar( + "ce_loss", ce_loss_sum / ntokens, ntokens, 2, round=3 + ) + + total = utils.item(sum(log.get("total", 0) for log in s2t_logging_output)) + if total > 0: + metrics.log_scalar("s2t_total", total) + n_correct = utils.item( + sum(log.get("n_correct", 0) for log in s2t_logging_output) + ) + metrics.log_scalar("s2t_n_correct", n_correct) + metrics.log_derived( + "s2t_accuracy", + lambda meters: round( + meters["s2t_n_correct"].sum * 100.0 / meters["s2t_total"].sum, 3 + ) + if meters["s2t_total"].sum > 0 + else float("nan"), + 2 + ) + c_errors = sum(log.get("c_errors", 0) for log in s2t_logging_output) + metrics.log_scalar("_c_errors", c_errors) + c_total = sum(log.get("c_total", 0) for log in s2t_logging_output) + metrics.log_scalar("_c_total", c_total) + w_errors = sum(log.get("w_errors", 0) for log in s2t_logging_output) + metrics.log_scalar("_w_errors", w_errors) + wv_errors = sum(log.get("wv_errors", 0) for log in s2t_logging_output) + metrics.log_scalar("_wv_errors", wv_errors) + w_total = sum(log.get("w_total", 0) for log in s2t_logging_output) + metrics.log_scalar("_w_total", w_total) + if c_total > 0: + metrics.log_derived( + "uer", + lambda meters: safe_round( + meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3 + ) + if 
meters["_c_total"].sum > 0 + else float("nan"), + ) + if w_total > 0: + metrics.log_derived( + "wer", + lambda meters: safe_round( + meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3 + ) + if meters["_w_total"].sum > 0 + else float("nan"), + ) + metrics.log_derived( + "raw_wer", + lambda meters: safe_round( + meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3 + ) + if meters["_w_total"].sum > 0 + else float("nan"), + ) + + if task_name == 't2s': + # TTSLossCriterion.reduce_metrics([logging_output['t2s'] for logging_output in logging_outputs]) + # t2s_sum = sum(log.get("speech_loss", 0) for log in logging_outputs) + t2s_logging_output = logging_outputs_dict[task_name] + loss_sum = sum(log.get("loss", 0) for log in t2s_logging_output) + l1_loss_sum = sum(log.get("l1_loss", 0) for log in t2s_logging_output) + l2_loss_sum = sum(log.get("l2_loss", 0) for log in t2s_logging_output) + bce_loss_sum = sum(log.get("bce_loss", 0) for log in t2s_logging_output) + sample_size = max(1, sum(log.get("sample_size", 0) for log in t2s_logging_output)) + metrics.log_scalar( + "t2s_loss", loss_sum / sample_size, sample_size, 1, round=5 + ) + encoder_alpha_sum = sum(log.get("encoder_alpha", 0) for log in t2s_logging_output) + decoder_alpha_sum = sum(log.get("decoder_alpha", 0) for log in t2s_logging_output) + ngpu = sum(log.get("ngpu", 0) for log in t2s_logging_output) + + metrics.log_scalar( + "t2s_l1_loss", l1_loss_sum / sample_size, sample_size, 2, round=5 + ) + metrics.log_scalar( + "t2s_l2_loss", l2_loss_sum / sample_size, sample_size, 2, round=5 + ) + metrics.log_scalar( + "t2s_bce_loss", bce_loss_sum / sample_size, sample_size, 2, round=5 + ) + metrics.log_scalar( + "t2s_encoder_alpha", encoder_alpha_sum / sample_size, sample_size, round=5 + ) + metrics.log_scalar( + "t2s_decoder_alpha", decoder_alpha_sum / sample_size, sample_size, round=5 + ) + + if "enc_dec_attn_loss" in t2s_logging_output[0]: + enc_dec_attn_loss_sum = sum(log.get("enc_dec_attn_loss", 0) for log in t2s_logging_output) + metrics.log_scalar( + "t2s_enc_dec_attn_loss", enc_dec_attn_loss_sum / sample_size, sample_size, round=8 + ) + + if task_name == 's2c': + s2c_logging_output = logging_outputs_dict[task_name] + loss_sum = sum(log.get("loss", 0) for log in s2c_logging_output) + nll_loss_sum = sum(log.get("nll_loss", 0) for log in s2c_logging_output) + ntokens = sum(log.get("ntokens", 0) for log in s2c_logging_output) + + sample_size = max(1, sum(log.get("sample_size", 0) for log in s2c_logging_output)) + metrics.log_scalar( + "s2c_loss", loss_sum / sample_size / math.log(2), sample_size, 1, round=3 + ) + + metrics.log_scalar( + "s2c_nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, 2, round=3 + ) + + total = utils.item(sum(log.get("total", 0) for log in s2c_logging_output)) + if total > 0: + metrics.log_scalar("s2c_total", total) + n_correct = utils.item(sum(log.get("n_correct", 0) for log in s2c_logging_output)) + metrics.log_scalar("s2c_n_correct", n_correct) + metrics.log_derived( + "s2c_accuracy", + lambda meters: round( + meters["s2c_n_correct"].sum * 100.0 / meters["s2c_total"].sum, 3 + ) + if meters["s2c_total"].sum > 0 + else float("nan"), + 2 + ) + + if task_name == 's2s': + s2s_logging_output = logging_outputs_dict[task_name] + loss_sum = sum(log.get("loss", 0) for log in s2s_logging_output) + l1_loss_sum = sum(log.get("l1_loss", 0) for log in s2s_logging_output) + l2_loss_sum = sum(log.get("l2_loss", 0) for log in s2s_logging_output) + bce_loss_sum = sum(log.get("bce_loss", 0) for log in 
s2s_logging_output) + sample_size = max(1, sum(log.get("sample_size", 0) for log in s2s_logging_output)) + metrics.log_scalar( + "s2s_loss", loss_sum / sample_size, sample_size, 1, round=5 + ) + encoder_alpha_sum = sum(log.get("encoder_alpha", 0) for log in s2s_logging_output) + decoder_alpha_sum = sum(log.get("decoder_alpha", 0) for log in s2s_logging_output) + ngpu = sum(log.get("ngpu", 0) for log in s2s_logging_output) + + metrics.log_scalar( + "s2s_l1_loss", l1_loss_sum / sample_size, sample_size, 2, round=5 + ) + metrics.log_scalar( + "s2s_l2_loss", l2_loss_sum / sample_size, sample_size, 2, round=5 + ) + metrics.log_scalar( + "s2s_bce_loss", bce_loss_sum / sample_size, sample_size, 2, round=5 + ) + metrics.log_scalar( + "s2s_decoder_alpha", decoder_alpha_sum / sample_size, sample_size, round=5 + ) + + if "enc_dec_attn_loss" in s2s_logging_output[0]: + enc_dec_attn_loss_sum = sum(log.get("enc_dec_attn_loss", 0) for log in s2s_logging_output) + metrics.log_scalar( + "s2s_enc_dec_attn_loss", enc_dec_attn_loss_sum / sample_size, sample_size, round=8 + ) + + if task_name == 'text_pretrain': + bart_logging_output = logging_outputs_dict[task_name] + loss_sum = sum(log.get("loss", 0) for log in bart_logging_output) + ntokens = sum(log.get("ntokens", 0) for log in bart_logging_output) + sample_size = max(1, sum(log.get("sample_size", 0) for log in bart_logging_output)) + bart_loss_sum = sum(log.get("bart_loss", 0) for log in bart_logging_output) + + # we divide by log(2) to convert the loss from base e to base 2 + metrics.log_scalar( + "text_loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + metrics.log_scalar( + "bart_loss", bart_loss_sum / sample_size / math.log(2), ntokens, 2, round=3 + ) + if sample_size != ntokens: + metrics.log_scalar( + "bart_nll_loss", bart_loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + metrics.log_derived( + "bart_ppl", lambda meters: utils.get_perplexity(meters["bart_nll_loss"].avg) + ) + else: + metrics.log_derived( + "bart_ppl", lambda meters: utils.get_perplexity(meters["bart_loss"].avg) + ) + metrics.log_scalar("bart_wpb", ntokens, priority=180, round=1) + + val_prob_perplexity = 0 + val_code_perplexity = 0 + sample_size_pp = 0 + count_log_cp = 0 + for log in bart_logging_output: + if "loss_prob_perplexity" in log: + val_prob_perplexity = val_prob_perplexity + log["loss_prob_perplexity"] + sample_size_pp = sample_size_pp + log["sample_size"] + if "code_perplexity" in log: + val_code_perplexity = val_code_perplexity + log["code_perplexity"] + count_log_cp = count_log_cp + 1 + if val_prob_perplexity > 0: + metrics.log_scalar("text_loss_prob_perplexity", val_prob_perplexity / sample_size_pp / math.log(2), round=3) + if val_code_perplexity > 0: + metrics.log_scalar("text_code_perplexity", val_code_perplexity / count_log_cp, round=3) + + if task_name == 'speech_pretrain': + hubert_logging_output = logging_outputs_dict[task_name] + loss_sum = sum(log.get("loss", 0) for log in hubert_logging_output) + ntokens = sum(log.get("ntokens", 0) for log in hubert_logging_output) + sample_size = max(1, sum(log.get("sample_size", 0) for log in hubert_logging_output)) + dec_loss_sum = sum(log.get("dec_loss", 0) for log in hubert_logging_output) + l1_loss_sum = sum(log.get("l1_loss", 0) for log in hubert_logging_output) + l2_loss_sum = sum(log.get("l2_loss", 0) for log in hubert_logging_output) + bce_loss_sum = sum(log.get("bce_loss", 0) for log in hubert_logging_output) + ngpu = sum(log.get("ngpu", 0) for log in hubert_logging_output) + + 
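+ # Losses are accumulated in nats; dividing by math.log(2) converts them to
+ # bits so the derived perplexities (utils.get_perplexity, base 2 by default
+ # in fairseq) line up across tasks. Hypothetical numbers for illustration:
+ #   loss_sum = 693.1 nats over sample_size = 100 frames
+ #   hubert_loss = 693.1 / 100 / math.log(2) ≈ 10.0 bits per frame
+ #   hubert_ppl  = 2 ** 10.0 = 1024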
metrics.log_scalar("hubert_loss", loss_sum / sample_size / math.log(2), sample_size, round=3) + if sample_size != ntokens: + metrics.log_scalar("hubert_nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3) + metrics.log_derived("hubert_ppl", lambda meters: utils.get_perplexity(meters["hubert_nll_loss"].avg)) + else: + metrics.log_derived("hubert_ppl", lambda meters: utils.get_perplexity(meters["hubert_loss"].avg)) + + counts = {} + for lk in hubert_logging_output[0].keys(): + if lk.startswith("count_"): + val = sum(log[lk] for log in hubert_logging_output) + metrics.log_scalar("hubert_" + lk, val) + counts[lk] = val + + for lk in hubert_logging_output[0].keys(): + if lk.startswith("loss_") and lk != 'loss_prob_perplexity': + val = sum(log[lk] for log in hubert_logging_output) + metrics.log_scalar("hubert_" + lk, val / sample_size / math.log(2), round=3) + elif lk.startswith("correct_"): + val = sum(log[lk] for log in hubert_logging_output) + metrics.log_scalar("hubert_" + lk, val / counts[re.sub("correct", "count", lk)]) + # elif lk == 'code_perplexity': + # val = sum(log[lk] for log in hubert_logging_output) + # metrics.log_scalar("hubert_" + lk, val / len(hubert_logging_output), round=3) + + val_prob_perplexity = 0 + val_code_perplexity = 0 + sample_size_pp = 0 + count_log_cp = 0 + for log in hubert_logging_output: + if "loss_prob_perplexity" in log: + val_prob_perplexity = val_prob_perplexity + log["loss_prob_perplexity"] + sample_size_pp = sample_size_pp + log["sample_size"] + if "code_perplexity" in log: + val_code_perplexity = val_code_perplexity + log["code_perplexity"] + count_log_cp = count_log_cp + 1 + if val_prob_perplexity > 0: + metrics.log_scalar("hubert_loss_prob_perplexity", val_prob_perplexity / sample_size_pp / math.log(2), round=3) + if val_code_perplexity > 0: + metrics.log_scalar("hubert_code_perplexity", val_code_perplexity / count_log_cp, round=3) + + metrics.log_scalar( + "hubert_dec_loss", dec_loss_sum / ngpu, sample_size, 2, round=5 + ) + metrics.log_scalar( + "hubert_l1_loss", l1_loss_sum / ngpu, sample_size, 2, round=5 + ) + metrics.log_scalar( + "hubert_l2_loss", l2_loss_sum / ngpu, sample_size, 2, round=5 + ) + metrics.log_scalar( + "hubert_bce_loss", bce_loss_sum / ngpu, sample_size, 2, round=5 + ) + if "enc_dec_attn_loss" in hubert_logging_output[0]: + enc_dec_attn_loss_sum = sum(log.get("enc_dec_attn_loss", 0) for log in hubert_logging_output) + metrics.log_scalar( + "hubert_enc_dec_attn_loss", enc_dec_attn_loss_sum / ngpu, sample_size, round=8 + ) + metrics.log_scalar("hubert_wpb", ntokens, priority=180, round=1) + + loss = sum(log.get("loss", 0) for log in logging_outputs) + sample_size = max(1, sum(log.get("sample_size", 0) for log in logging_outputs)) + metrics.log_scalar( + "loss", loss / sample_size, sample_size, 1, round=5 + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. 
+ """ + return False diff --git a/artst/criterions/speech_pretrain_criterion.py b/artst/criterions/speech_pretrain_criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..392c9a60a209ccab2c12033a84c35bf1a7411f5e --- /dev/null +++ b/artst/criterions/speech_pretrain_criterion.py @@ -0,0 +1,265 @@ +# -------------------------------------------------------- +# ArTST: Arabic Text and Speech Transform (https://arxiv.org/abs/2310.16621) +# Github source: https://github.com/mbzuai-nlp/ArTST +# Based on speecht5, fairseq and espnet code bases +# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet +# -------------------------------------------------------- + +import math +import re +from dataclasses import dataclass, field +from typing import List, Optional + +import torch +import torch.nn.functional as F +from fairseq import metrics, utils +from fairseq.criterions import FairseqCriterion +from artst.criterions.text_to_speech_loss import TexttoSpeechLoss, TexttoSpeechLossConfig + + +@dataclass +class SpeechPretrainCriterionConfig(TexttoSpeechLossConfig): + pred_masked_weight: float = field( + default=1.0, + metadata={"help": "weight for predictive loss for masked frames"}, + ) + pred_nomask_weight: float = field( + default=0.0, + metadata={"help": "weight for predictive loss for unmasked frames"}, + ) + loss_weights: Optional[List[float]] = field( + default_factory=lambda: [10,], + metadata={"help": "weights for additional loss terms (not first one)"}, + ) + log_keys: List[str] = field( + default_factory=lambda: [], + metadata={"help": "output keys to log"}, + ) + hubert_weight: float = field( + default=1.0, + metadata={"help": "weight of hubert loss"}, + ) + dec_weight: float = field( + default=1.0, + metadata={"help": "weight of decoder loss"}, + ) + + +class SpeechPretrainCriterion(FairseqCriterion): + def __init__( + self, + task, + sentence_avg, + pred_masked_weight, + pred_nomask_weight, + loss_weights=None, + log_keys=None, + use_masking=True, + use_weighted_masking=False, + loss_type="L1", + bce_pos_weight=5.0, + hubert_weight=1.0, + dec_weight=1.0, + ): + super().__init__(task) + self.pred_masked_weight = pred_masked_weight + self.pred_nomask_weight = pred_nomask_weight + self.loss_weights = loss_weights + self.log_keys = [] if log_keys is None else log_keys + self.hubert_weight = hubert_weight + self.dec_weight = dec_weight + + self.speech_criterion = TexttoSpeechLoss( + task, + sentence_avg, + use_masking, + use_weighted_masking, + loss_type, + bce_pos_weight, + ) + + def forward(self, model, sample, reduce=True, log_pred=False): + """Compute the loss for the given sample. + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + if self.dec_weight == 0: + sample["net_input"]["only_hubert"] = True + net_output, net_output_dec = model(target_list=sample["target_list"], **sample["net_input"]) + loss = 0. 
+ sample_size = 0 + logging_output = {} + reduction = "sum" if reduce else "none" + + loss_m_list = [] + logp_m_list = model.get_logits(net_output, True) + targ_m_list = model.get_targets(None, net_output, True) + assert self.pred_masked_weight == 0 or len(logp_m_list) > 0 + for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)): + loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction) + loss_m_list.append(loss_m) + logging_output[f"loss_m_{i}"] = loss_m.detach().item() + if self.pred_masked_weight > 0: + loss += self.pred_masked_weight * sum(loss_m_list) + sample_size += targ_m_list[0].numel() + + loss_u_list = [] + logp_u_list = model.get_logits(net_output, False) + targ_u_list = model.get_targets(None, net_output, False) + assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0 + for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)): + loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction) + loss_u_list.append(loss_u) + logging_output[f"loss_u_{i}"] = loss_u.detach().item() + if self.pred_nomask_weight > 0: + loss += self.pred_nomask_weight * sum(loss_u_list) + sample_size += targ_u_list[0].numel() + + if self.loss_weights is not None: + assert hasattr(model, "get_extra_losses") + extra_losses, names = model.get_extra_losses(net_output) + if torch.is_tensor(extra_losses): + extra_losses = [extra_losses] + names = [names] + if len(self.loss_weights) == 1 and len(extra_losses) != 1: + self.loss_weights = [self.loss_weights[0]] * len(extra_losses) + if len(self.loss_weights) > len(extra_losses): + modified_loss_weight = self.loss_weights[:len(extra_losses)] + else: + modified_loss_weight = self.loss_weights + + # assert len(extra_losses) == len(self.loss_weights), f"{len(extra_losses)}, {len(self.loss_weights)}" + for p, n, coef in zip(extra_losses, names, modified_loss_weight): + # print(n + str(coef)) + if coef != 0 and p is not None: + p = coef * p.float() * sample_size + loss += p + logging_output[f"loss_{n}"] = p.detach().item() + + logging_output = { + "ntokens": sample_size, + "nsentences": sample["id"].numel(), + "sample_size": sample_size, + "ngpu": 1, + **logging_output, + } + + if 'loss_prob_perplexity' in logging_output: + logging_output['code_perplexity'] = net_output['code_perplexity'].detach().item() + + for lk in self.log_keys: + if lk in net_output: + logging_output[lk] = float((net_output[lk].item())) + + def compute_correct(logits): + if logits.numel() == 0: + return 0, 0 + else: + assert logits.dim() > 1, logits.shape + max = logits.argmax(-1) == 0 + min = logits.argmin(-1) == 0 + both = max & min + corr = max.long().sum().item() - both.long().sum().item() + count = max.numel() + return corr, count + + with torch.no_grad(): + for i, logp_m in enumerate(logp_m_list): + corr_m, count_m = compute_correct(logp_m) + logging_output[f"correct_m_{i}"] = corr_m + logging_output[f"count_m_{i}"] = count_m + + for i, logp_u in enumerate(logp_u_list): + corr_u, count_u = compute_correct(logp_u) + logging_output[f"correct_u_{i}"] = corr_u + logging_output[f"count_u_{i}"] = count_u + + if self.dec_weight == 0.0: + logging_output["loss"] = loss.item() if reduce else loss + return loss, sample_size, logging_output + +# ## dec loss + dec_loss, l1_loss, l2_loss, bce_loss, enc_dec_attn_loss = self.speech_criterion.compute_loss(model, net_output_dec, sample) + + # Log tts loss + logging_output['dec_loss'] = dec_loss.item() + logging_output['l1_loss'] = l1_loss.item() + logging_output['l2_loss'] = l2_loss.item() + logging_output['bce_loss'] = 
bce_loss.item() + if enc_dec_attn_loss is not None: + logging_output['enc_dec_attn_loss'] = enc_dec_attn_loss.item() + + loss = self.hubert_weight * loss + self.dec_weight * sample_size * dec_loss + logging_output["loss"] = loss.item() if reduce else loss + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training (copied from normal cross entropy).""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + dec_loss_sum = sum(log.get("dec_loss", 0) for log in logging_outputs) + l1_loss_sum = sum(log.get("l1_loss", 0) for log in logging_outputs) + l2_loss_sum = sum(log.get("l2_loss", 0) for log in logging_outputs) + bce_loss_sum = sum(log.get("bce_loss", 0) for log in logging_outputs) + ngpu = sum(log.get("ngpu", 0) for log in logging_outputs) + + metrics.log_scalar("loss", loss_sum / sample_size / math.log(2), sample_size, round=3) + if sample_size != ntokens: + metrics.log_scalar("nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3) + metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)) + else: + metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)) + + counts = {} + for lk in logging_outputs[0].keys(): + if lk.startswith("count_"): + val = sum(log[lk] for log in logging_outputs) + metrics.log_scalar(lk, val) + counts[lk] = val + + for lk in logging_outputs[0].keys(): + if lk.startswith("loss_"): + val = sum(log[lk] for log in logging_outputs) + metrics.log_scalar(lk, val / sample_size / math.log(2), round=3) + elif lk.startswith("correct_"): + val = sum(log[lk] for log in logging_outputs) + metrics.log_scalar(lk, val / counts[re.sub("correct", "count", lk)]) + elif lk == 'code_perplexity': + val = sum(log[lk] for log in logging_outputs) + metrics.log_scalar(lk, val / len(logging_outputs), round=3) + + metrics.log_scalar( + "dec_loss", dec_loss_sum / ngpu, sample_size, 2, round=5 + ) + metrics.log_scalar( + "l1_loss", l1_loss_sum / ngpu, sample_size, 2, round=5 + ) + metrics.log_scalar( + "l2_loss", l2_loss_sum / ngpu, sample_size, 2, round=5 + ) + metrics.log_scalar( + "bce_loss", bce_loss_sum / ngpu, sample_size, 2, round=5 + ) + if "enc_dec_attn_loss" in logging_outputs[0]: + enc_dec_attn_loss_sum = sum(log.get("enc_dec_attn_loss", 0) for log in logging_outputs) + metrics.log_scalar( + "enc_dec_attn_loss", enc_dec_attn_loss_sum / ngpu, sample_size, round=8 + ) + + @staticmethod + def aggregate_logging_outputs(logging_outputs): + """Aggregate logging outputs from data parallel training.""" + raise NotImplementedError() + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. 
+ """ + return False diff --git a/artst/criterions/speech_to_text_loss.py b/artst/criterions/speech_to_text_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..ea5f6c42def1e1470ce6c9933b3c963fcfa1053b --- /dev/null +++ b/artst/criterions/speech_to_text_loss.py @@ -0,0 +1,473 @@ +# -------------------------------------------------------- +# ArTST: Arabic Text and Speech Transform (https://arxiv.org/abs/2310.16621) +# Github source: https://github.com/mbzuai-nlp/ArTST +# Based on speecht5, fairseq and espnet code bases +# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet +# -------------------------------------------------------- + +import math +from argparse import Namespace +from dataclasses import dataclass, field +from omegaconf import II +from typing import Optional + +import torch +import torch.nn.functional as F +from fairseq import metrics, utils +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass +from fairseq.data.data_utils import post_process +from fairseq.tasks import FairseqTask +from fairseq.logging.meters import safe_round + +import logging +logger = logging.getLogger(__name__) + +@dataclass +class SpeechtoTextLossConfig(FairseqDataclass): + zero_infinity: bool = field( + default=False, + metadata={"help": "zero inf loss when source length <= target length"}, + ) + sentence_avg: bool = II("optimization.sentence_avg") + post_process: Optional[str] = field( + default="sentencepiece", + metadata={ + "help": "how to post process predictions into words. can be letter, " + "wordpiece, BPE symbols, etc. " + "See fairseq.data.data_utils.post_process() for full list of options" + }, + ) + wer_kenlm_model: Optional[str] = field( + default=None, + metadata={ + "help": "if this is provided, use kenlm to compute wer (along with other wer_* args)" + }, + ) + wer_lexicon: Optional[str] = field( + default=None, + metadata={"help": "lexicon to use with wer_kenlm_model"}, + ) + wer_lm_weight: float = field( + default=2.0, + metadata={"help": "lm weight to use with wer_kenlm_model"}, + ) + wer_word_score: float = field( + default=-1.0, + metadata={"help": "lm word score to use with wer_kenlm_model"}, + ) + + wer_args: Optional[str] = field( + default=None, + metadata={ + "help": "DEPRECATED: tuple of (wer_kenlm_model, wer_lexicon, wer_lm_weight, wer_word_score)" + }, + ) + + label_smoothing: float = field( + default=0.0, + metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"}, + ) + report_accuracy: bool = field( + default=False, + metadata={"help": "report accuracy metric"}, + ) + ignore_prefix_size: int = field( + default=0, + metadata={"help": "Ignore first N tokens"}, + ) + #: bool = II("optimization.sentence_avg") + + ce_weight: float = field( + default=1.0, + metadata={"help": "loss weight for cross entropy"}, + ) + ctc_weight: float = field( + default=0.0, + metadata={"help": "loss weiehgt for ctc in ASR"}, + ) + + +def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True): + if target.dim() == lprobs.dim() - 1: + target = target.unsqueeze(-1) + nll_loss = -lprobs.gather(dim=-1, index=target) + smooth_loss = -lprobs.sum(dim=-1, keepdim=True) + if ignore_index is not None: + pad_mask = target.eq(ignore_index) + nll_loss.masked_fill_(pad_mask, 0.0) + smooth_loss.masked_fill_(pad_mask, 0.0) + else: + nll_loss = nll_loss.squeeze(-1) + smooth_loss = smooth_loss.squeeze(-1) + if 
reduce: + nll_loss = nll_loss.sum() + smooth_loss = smooth_loss.sum() + eps_i = epsilon / (lprobs.size(-1) - 1) + loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss + return loss, nll_loss + + +class SpeechtoTextLoss(FairseqCriterion): + def __init__( + self, + cfg: SpeechtoTextLossConfig, + task: FairseqTask, + sentence_avg=True, + label_smoothing=0.1, + ignore_prefix_size=0, + report_accuracy=False, + ce_weight=1.0, + ctc_weight=0.0, + ): + + super().__init__(task) + self.blank_idx = ( + task.target_dictionary.index(task.blank_symbol) + if hasattr(task, "blank_symbol") + else 0 + ) + #print ("self.blank_idx: ", self.blank_idx) + + self.pad_idx = task.target_dictionary.pad() + self.eos_idx = task.target_dictionary.eos() + self.post_process = cfg.post_process + self.ce_weight = ce_weight + self.ctc_weight = ctc_weight + + ## for ce + self.sentence_avg = sentence_avg + self.eps = label_smoothing + self.ignore_prefix_size = ignore_prefix_size + self.report_accuracy = report_accuracy + + if cfg.wer_args is not None: + ( + cfg.wer_kenlm_model, + cfg.wer_lexicon, + cfg.wer_lm_weight, + cfg.wer_word_score, + ) = eval(cfg.wer_args) + + if cfg.wer_kenlm_model is not None: + from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder + + dec_args = Namespace() + dec_args.nbest = 1 + dec_args.criterion = "ctc" + dec_args.kenlm_model = cfg.wer_kenlm_model + dec_args.lexicon = cfg.wer_lexicon + dec_args.beam = 50 + dec_args.beam_size_token = min(50, len(task.target_dictionary)) + dec_args.beam_threshold = min(50, len(task.target_dictionary)) + dec_args.lm_weight = cfg.wer_lm_weight + dec_args.word_score = cfg.wer_word_score + dec_args.unk_weight = -math.inf + dec_args.sil_weight = 0 + + self.w2l_decoder = W2lKenLMDecoder(dec_args, task.target_dictionary) + else: + self.w2l_decoder = None + + self.zero_infinity = cfg.zero_infinity + #self.sentence_avg = cfg.sentence_avg + + if self.ce_weight > 0 and self.ctc_weight > 0: + logger.info("Using cross entropy loss and CTC loss for ASR") + elif self.ce_weight > 0: + logger.info("Only using CE loss") + elif self.ctc_weight > 0: + logger.info("Only using CTC loss for ASR") + else: + logger.info("ERROR") + + def forward(self, model, sample, reduce=True): + + if self.ce_weight == 0 and self.ctc_weight > 0: + sample["only_ctc"] = True + + net_output_decoder, net_output = model(**sample["net_input"]) + + if self.ce_weight > 0: + loss_ce, nll_loss_ce = self.compute_loss(model, net_output_decoder, sample, reduce=reduce) + #print ("loss_ce: ", loss_ce) + else: + nll_loss_ce = None + + if self.ctc_weight > 0: + loss_ctc, lprobs, input_lengths = self.compute_loss_ctc(model, net_output, sample) + + if self.ce_weight > 0 and self.ctc_weight > 0: + loss = self.ce_weight * loss_ce + self.ctc_weight * loss_ctc + elif self.ce_weight > 0: + loss = loss_ce + elif self.ctc_weight > 0: + loss = loss_ctc + else: + logger.info("ERROR: must ce_weight > 0 or ctc_weight > 0") + + ntokens = ( + sample["ntokens"] if "ntokens" in sample else sample["target_lengths"].sum().item() + ) + + sample_size = sample["target"].size(0) if self.sentence_avg else ntokens + + logging_output = { + "loss": loss.item(), + "ce_loss": loss_ce.item() if self.ce_weight > 0 else 0, + "ctc_loss": loss_ctc.item() if self.ctc_weight > 0 else 0, + "nll_loss": nll_loss_ce.item() if nll_loss_ce is not None else 0, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + } + + if self.ce_weight > 0 and self.report_accuracy: + n_correct, total = 
self.compute_accuracy(model, net_output_decoder, sample) + logging_output["n_correct"] = utils.item(n_correct.item()) + logging_output["total"] = utils.item(total.data) + + if self.ctc_weight > 0 and not model.training: + import editdistance + + with torch.no_grad(): + lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu() + + c_err = 0 + c_len = 0 + w_errs = 0 + w_len = 0 + wv_errs = 0 + for lp, t, inp_l in zip( + lprobs_t, + sample["target_label"] + if "target_label" in sample + else sample["target"], + input_lengths, + ): + lp = lp[:inp_l].unsqueeze(0) + + decoded = None + if self.w2l_decoder is not None: + decoded = self.w2l_decoder.decode(lp) + if len(decoded) < 1: + decoded = None + else: + decoded = decoded[0] + if len(decoded) < 1: + decoded = None + else: + decoded = decoded[0] + + p = (t != self.task.target_dictionary.pad()) & ( + t != self.task.target_dictionary.eos() + ) + targ = t[p] + targ_units = self.task.target_dictionary.string(targ) + targ_units_arr = targ.tolist() + + toks = lp.argmax(dim=-1).unique_consecutive() + pred_units_arr = toks[toks != self.blank_idx].tolist() + + c_err += editdistance.eval(pred_units_arr, targ_units_arr) + c_len += len(targ_units_arr) + + targ_words = post_process(targ_units, self.post_process).split() + + pred_units = self.task.target_dictionary.string(pred_units_arr) + pred_words_raw = post_process(pred_units, self.post_process).split() + + if decoded is not None and "words" in decoded: + pred_words = decoded["words"] + w_errs += editdistance.eval(pred_words, targ_words) + wv_errs += editdistance.eval(pred_words_raw, targ_words) + else: + dist = editdistance.eval(pred_words_raw, targ_words) + w_errs += dist + wv_errs += dist + + w_len += len(targ_words) + + logging_output["wv_errors"] = wv_errs + logging_output["w_errors"] = w_errs + logging_output["w_total"] = w_len + logging_output["c_errors"] = c_err + logging_output["c_total"] = c_len + + return loss, sample_size, logging_output + + def compute_loss_ctc(self, model, net_output, sample): + lprobs = model.get_normalized_probs_for_ctc( + net_output, log_probs=True + ).contiguous() # (T, B, C) from the encoder + + if net_output["encoder_padding_mask"] is not None: + non_padding_mask = ~net_output["encoder_padding_mask"][0] + input_lengths = non_padding_mask.long().sum(-1) + else: + input_lengths = lprobs.new_full( + (lprobs.size(1),), lprobs.size(0), dtype=torch.long + ) + + pad_mask = (sample["target"] != self.pad_idx) & ( + sample["target"] != self.eos_idx + ) + targets_flat = sample["target"].masked_select(pad_mask) + if "target_lengths" in sample: + target_lengths = sample["target_lengths"] + else: + target_lengths = pad_mask.sum(-1) + + ##processing + target_lengths = target_lengths - 1 + + with torch.backends.cudnn.flags(enabled=False): + loss_ctc = F.ctc_loss( + lprobs, + targets_flat, + input_lengths, + target_lengths, + blank=self.blank_idx, + reduction="sum", + zero_infinity=True, + ) + + return loss_ctc, lprobs, input_lengths + + ## for ce + def get_lprobs_and_target(self, model, net_output, sample): + lprobs = model.get_normalized_probs(net_output, log_probs=True) + target = model.get_targets(sample, net_output) + if self.ignore_prefix_size > 0: + if getattr(lprobs, "batch_first", False): + lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous() + target = target[:, self.ignore_prefix_size :].contiguous() + else: + lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous() + target = target[self.ignore_prefix_size :, :].contiguous() + return lprobs.view(-1, 
lprobs.size(-1)), target.view(-1) + + def compute_loss(self, model, net_output, sample, reduce=True): + lprobs, target = self.get_lprobs_and_target(model, net_output, sample) + loss, nll_loss = label_smoothed_nll_loss( + lprobs, + target, + self.eps, + ignore_index=self.padding_idx, + reduce=reduce, + ) + return loss, nll_loss + + def compute_accuracy(self, model, net_output, sample): + lprobs, target = self.get_lprobs_and_target(model, net_output, sample) + mask = target.ne(self.padding_idx) + n_correct = torch.sum( + lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask)) + ) + total = torch.sum(mask) + return n_correct, total + + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + + loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs)) + nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs) + ce_loss_sum = sum(log.get("ce_loss", 0) for log in logging_outputs) + ctc_loss_sum = sum(log.get("ctc_loss", 0) for log in logging_outputs) + ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs)) + nsentences = utils.item( + sum(log.get("nsentences", 0) for log in logging_outputs) + ) + sample_size = utils.item( + sum(log.get("sample_size", 0) for log in logging_outputs) + ) + + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + + metrics.log_scalar( + "ctc_loss", ctc_loss_sum / sample_size / math.log(2), ntokens, 2, round=3 + ) + metrics.log_scalar( + "ce_loss", ce_loss_sum / ntokens, ntokens, 2, round=3 + ) + metrics.log_scalar( + "nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, 2, round=3 + ) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg, 2) + ) + + total = utils.item(sum(log.get("total", 0) for log in logging_outputs)) + if total > 0: + metrics.log_scalar("total", total) + n_correct = utils.item( + sum(log.get("n_correct", 0) for log in logging_outputs) + ) + metrics.log_scalar("n_correct", n_correct) + metrics.log_derived( + "accuracy", + lambda meters: round( + meters["n_correct"].sum * 100.0 / meters["total"].sum, 3 + ) + if meters["total"].sum > 0 + else float("nan"), + 2 + ) + + metrics.log_scalar("ntokens", ntokens) + metrics.log_scalar("nsentences", nsentences) + if sample_size != ntokens: + metrics.log_scalar( + "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + + c_errors = sum(log.get("c_errors", 0) for log in logging_outputs) + metrics.log_scalar("_c_errors", c_errors) + c_total = sum(log.get("c_total", 0) for log in logging_outputs) + metrics.log_scalar("_c_total", c_total) + w_errors = sum(log.get("w_errors", 0) for log in logging_outputs) + metrics.log_scalar("_w_errors", w_errors) + wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs) + metrics.log_scalar("_wv_errors", wv_errors) + w_total = sum(log.get("w_total", 0) for log in logging_outputs) + metrics.log_scalar("_w_total", w_total) + + if c_total > 0: + metrics.log_derived( + "uer", + lambda meters: safe_round( + meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3 + ) + if meters["_c_total"].sum > 0 + else float("nan"), + ) + if w_total > 0: + metrics.log_derived( + "wer", + lambda meters: safe_round( + meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3 + ) + if meters["_w_total"].sum > 0 + else float("nan"), + ) + metrics.log_derived( + "raw_wer", + lambda meters: safe_round( + meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 
3 + ) + if meters["_w_total"].sum > 0 + else float("nan"), + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True diff --git a/artst/criterions/text_pretrain_criterion.py b/artst/criterions/text_pretrain_criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..8c6d3b78316891ebf9e337e4f60e79eba9d55464 --- /dev/null +++ b/artst/criterions/text_pretrain_criterion.py @@ -0,0 +1,142 @@ +# -------------------------------------------------------- +# ArTST: Arabic Text and Speech Transform (https://arxiv.org/abs/2310.16621) +# Github source: https://github.com/mbzuai-nlp/ArTST +# Based on speecht5, fairseq and espnet code bases +# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet +# -------------------------------------------------------- + +import math +from dataclasses import dataclass, field +from typing import List, Optional + +import torch +import torch.nn.functional as F +from fairseq import metrics, utils +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass +from omegaconf import II + + +@dataclass +class TextPretrainCriterionConfig(FairseqDataclass): + sentence_avg: bool = II("optimization.sentence_avg") + loss_weights: Optional[List[float]] = field( + default_factory=lambda: [0.1,], + metadata={"help": "weights for additional loss terms (not first one)"}, + ) + bart_weight: float = field( + default=1.0, + metadata={"help": "loss weight for cross entropy"}, + ) + + +class TextPretrainCriterion(FairseqCriterion): + def __init__(self, task, sentence_avg, bart_weight, loss_weights=None): + super().__init__(task) + self.sentence_avg = sentence_avg + self.loss_weights = loss_weights + self.bart_weight = bart_weight + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. 
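+ 
+ The text model is trained with a BART-style denoising objective: the
+ decoder output is scored with negative log-likelihood against the target
+ tokens and scaled by bart_weight, and any quantizer losses exposed by the
+ model (e.g. prob_perplexity) are added through loss_weights.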
+ + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output, codebook_out, encoder_output = model(**sample["net_input"]) + bart_loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce) + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) + + loss = self.bart_weight * bart_loss + logging_output = { + "loss": loss.item(), + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "bart_loss": bart_loss.item(), + "sample_size": sample_size, + } + + if "prob_perplexity" in codebook_out: + assert hasattr(model, "get_extra_losses") + extra_losses, names = model.get_extra_losses(codebook_out) + if torch.is_tensor(extra_losses): + extra_losses = [extra_losses] + names = [names] + if len(self.loss_weights) == 1 and len(extra_losses) != 1: + self.loss_weights = [self.loss_weights[0]] * len(extra_losses) + if len(self.loss_weights) > len(extra_losses): + modified_loss_weight = self.loss_weights[len(extra_losses):] + else: + modified_loss_weight = self.loss_weights + + # assert len(extra_losses) == len(self.loss_weights), f"{len(extra_losses)}, {len(self.loss_weights)}" + for p, n, coef in zip(extra_losses, names, modified_loss_weight): + # print(n + str(coef)) + if coef != 0 and p is not None: + p = coef * p.float() * sample_size + loss += p + logging_output[f"loss_{n}"] = p.item() + + if 'loss_prob_perplexity' in logging_output: + logging_output['code_perplexity'] = codebook_out['code_perplexity'].item() + + return loss, sample_size, logging_output + + def compute_loss(self, model, net_output, sample, reduce=True): + lprobs = model.get_normalized_probs(net_output, log_probs=True) + lprobs = lprobs.view(-1, lprobs.size(-1)) + target = model.get_targets(sample, net_output).view(-1) + loss = F.nll_loss( + lprobs, + target, + ignore_index=self.padding_idx, + reduction="sum" if reduce else "none", + ) + return loss, loss + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + bart_loss_sum = sum(log.get("bart_loss", 0) for log in logging_outputs) + + # we divide by log(2) to convert the loss from base e to base 2 + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + metrics.log_scalar( + "bart_loss", bart_loss_sum / sample_size / math.log(2), ntokens, 2, round=3 + ) + if sample_size != ntokens: + metrics.log_scalar( + "nll_loss", bart_loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg) + ) + else: + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["bart_loss"].avg) + ) + + if "loss_prob_perplexity" in logging_outputs[0].keys(): + val = sum(log["loss_prob_perplexity"] for log in logging_outputs) + metrics.log_scalar("loss_prob_perplexity", val / sample_size / math.log(2), round=3) + if "code_perplexity" in logging_outputs[0].keys(): + val = sum(log["code_perplexity"] for log in logging_outputs) + metrics.log_scalar("code_perplexity", val / len(logging_outputs), round=3) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs 
returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True diff --git a/artst/criterions/text_to_speech_loss.py b/artst/criterions/text_to_speech_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..b0521eecceb362de08882f12d210fc044eb94691 --- /dev/null +++ b/artst/criterions/text_to_speech_loss.py @@ -0,0 +1,425 @@ +# -------------------------------------------------------- +# ArTST: Arabic Text and Speech Transform (https://arxiv.org/abs/2310.16621) +# Github source: https://github.com/mbzuai-nlp/ArTST +# Based on speecht5, fairseq and espnet code bases +# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet +# -------------------------------------------------------- + +from dataclasses import dataclass, field + +import torch +from fairseq import metrics, utils +from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass +from artst.models.modules.speech_encoder_prenet import SpeechEncoderPrenet +from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import GuidedAttentionLoss +from omegaconf import II +from typing import Any + + +@dataclass +class TexttoSpeechLossConfig(FairseqDataclass): + use_masking: bool = field( + default=True, + metadata={"help": "Whether to use masking in calculation of loss"}, + ) + use_weighted_masking: bool = field( + default=False, + metadata={"help": "Whether to use weighted masking in calculation of loss"}, + ) + loss_type: str = field( + default="L1", + metadata={"help": "How to calc loss"}, + ) + bce_pos_weight: float = field( + default=5.0, + metadata={"help": "Positive sample weight in BCE calculation (only for use-masking=True)"}, + ) + bce_loss_lambda: float = field( + default=1.0, + metadata={"help": "Lambda in bce loss"}, + ) + use_guided_attn_loss: bool = field( + default=False, + metadata={"help": "Whether to use guided attention loss"}, + ) + guided_attn_loss_sigma: float = field( + default=0.4, + metadata={"help": "Sigma in guided attention loss"}, + ) + guided_attn_loss_lambda: float = field( + default=10.0, + metadata={"help": "Lambda in guided attention loss"}, + ) + num_layers_applied_guided_attn: int = field( + default=2, + metadata={"help": "Number of layers to be applied guided attention loss, if set -1, all of the layers will be applied."}, + ) + num_heads_applied_guided_attn: int = field( + default=2, + metadata={"help": "Number of heads in each layer to be applied guided attention loss, if set -1, all of the heads will be applied."}, + ) + modules_applied_guided_attn: Any = field( + default=("encoder-decoder",), + metadata={"help": "Module name list to be applied guided attention loss"}, + ) + sentence_avg: bool = II("optimization.sentence_avg") + + +class TexttoSpeechLoss(FairseqCriterion): + def __init__( + self, + task, + sentence_avg, + use_masking=True, + use_weighted_masking=False, + loss_type="L1", + bce_pos_weight=5.0, + bce_loss_lambda=1.0, + use_guided_attn_loss=False, + guided_attn_loss_sigma=0.4, + guided_attn_loss_lambda=1.0, + num_layers_applied_guided_attn=2, + num_heads_applied_guided_attn=2, + modules_applied_guided_attn=["encoder-decoder"], + ): + super().__init__(task) + self.sentence_avg = sentence_avg + self.use_masking = use_masking + self.use_weighted_masking = use_weighted_masking + 
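+ # Tacotron2-style TTS objective: spectrogram regression (L1, L2, or L1+L2,
+ # applied to both pre- and post-postnet outputs) plus a BCE stop-token loss
+ # scaled by bce_loss_lambda, and optionally a guided attention penalty on
+ # the selected encoder-decoder attention heads. Sketch of how compute_loss
+ # below combines the terms for the default loss_type="L1":
+ #   loss = l1_loss + bce_loss_lambda * bce_loss
+ #   if use_guided_attn_loss:
+ #       loss = loss + enc_dec_attn_loss   # GuidedMultiHeadAttentionLoss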
self.loss_type = loss_type + self.bce_pos_weight = bce_pos_weight + self.bce_loss_lambda = bce_loss_lambda + self.use_guided_attn_loss = use_guided_attn_loss + self.guided_attn_loss_sigma = guided_attn_loss_sigma + self.guided_attn_loss_lambda = guided_attn_loss_lambda + # define loss function + self.criterion = Tacotron2Loss( + use_masking=use_masking, + use_weighted_masking=use_weighted_masking, + bce_pos_weight=bce_pos_weight, + ) + if self.use_guided_attn_loss: + self.num_layers_applied_guided_attn = num_layers_applied_guided_attn + self.num_heads_applied_guided_attn = num_heads_applied_guided_attn + self.modules_applied_guided_attn = modules_applied_guided_attn + if self.use_guided_attn_loss: + self.attn_criterion = GuidedMultiHeadAttentionLoss( + sigma=guided_attn_loss_sigma, + alpha=guided_attn_loss_lambda, + ) + + def forward(self, model, sample): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(**sample["net_input"]) + loss, l1_loss, l2_loss, bce_loss, enc_dec_attn_loss = self.compute_loss(model, net_output, sample) + # sample_size = ( + # sample["target"].size(0) if self.sentence_avg else sample["nframes"] + # ) + sample_size = 1 + logging_output = { + "loss": loss.item(), + "l1_loss": l1_loss.item(), + "l2_loss": l2_loss.item(), + "bce_loss": bce_loss.item(), + "sample_size": 1, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + } + + if enc_dec_attn_loss is not None: + logging_output['enc_dec_attn_loss'] = enc_dec_attn_loss.item() + + if hasattr(model, 'text_encoder_prenet'): + logging_output["encoder_alpha"] = model.text_encoder_prenet.encoder_prenet[-1].alpha.item() + logging_output["decoder_alpha"] = model.speech_decoder_prenet.decoder_prenet[-1].alpha.item() + elif hasattr(model, "speech_encoder_prenet"): + logging_output["decoder_alpha"] = model.speech_decoder_prenet.decoder_prenet[-1].alpha.item() + else: + if 'task' not in sample: + logging_output["encoder_alpha"] = model.encoder_prenet.encoder_prenet[-1].alpha.item() + logging_output["decoder_alpha"] = model.decoder_prenet.decoder_prenet[-1].alpha.item() + + return loss, sample_size, logging_output + + def compute_loss(self, model, net_output, sample): + before_outs, after_outs, logits, attn = net_output + labels = sample["labels"] + ys = sample["dec_target"] + olens = sample["dec_target_lengths"] + ilens = sample["src_lengths"] + + # modifiy mod part of groundtruth + if model.reduction_factor > 1: + olens_in = olens.new([torch.div(olen, model.reduction_factor, rounding_mode='floor') for olen in olens]) + olens = olens.new([olen - olen % model.reduction_factor for olen in olens]) + max_olen = max(olens) + ys = ys[:, :max_olen] + labels = labels[:, :max_olen] + labels = torch.scatter(labels, 1, (olens - 1).unsqueeze(1), 1.0) # make sure at least one frame has 1 + # labels[:, -1] = 1.0 + else: + olens_in = olens + + # caluculate loss values + l1_loss, l2_loss, bce_loss = self.criterion( + after_outs, before_outs, logits, ys, labels, olens + ) + + # l1_loss = l1_loss / ys.size(2) + # l2_loss = l2_loss / ys.size(2) + + if self.loss_type == "L1": + loss = l1_loss + self.bce_loss_lambda * bce_loss if self.bce_loss_lambda > 0.0 else l1_loss + elif self.loss_type == "L2": + loss = l2_loss + self.bce_loss_lambda * bce_loss if self.bce_loss_lambda > 0.0 else l2_loss + elif self.loss_type == "L1+L2": + loss = 
l1_loss + l2_loss + self.bce_loss_lambda * bce_loss if self.bce_loss_lambda > 0.0 else l1_loss + l2_loss + else: + raise ValueError("unknown --loss-type " + self.loss_type) + + # calculate guided attention loss + enc_dec_attn_loss = None + if self.use_guided_attn_loss: + # calculate the input lengths of encoder, which is determined by encoder prenet + if hasattr(model, 'encoder_reduction_factor') and model.encoder_reduction_factor > 1: + ilens_in = ilens.new([ilen // model.encoder_reduction_factor for ilen in ilens]) + else: + ilens_in = ilens + # work for speech to speech model's input + if "task_name" in sample and sample["task_name"] == "s2s": + m = None + if hasattr(model, 'encoder_prenet'): + m = model.encoder_prenet + elif hasattr(model, 'speech_encoder_prenet'): + m = model.speech_encoder_prenet + if m is not None and isinstance(m, SpeechEncoderPrenet): + ilens_in = m.get_src_lengths(ilens_in) + # calculate for encoder-decoder + if "encoder-decoder" in self.modules_applied_guided_attn: + attn = [att_l[:, : self.num_heads_applied_guided_attn] for att_l in attn] + att_ws = torch.cat(attn, dim=1) # (B, H*L, T_out, T_in) + enc_dec_attn_loss = self.attn_criterion(att_ws, ilens_in, olens_in) + loss = loss + enc_dec_attn_loss + + return loss, l1_loss, l2_loss, bce_loss, enc_dec_attn_loss + + @classmethod + def reduce_metrics(cls, logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + l1_loss_sum = sum(log.get("l1_loss", 0) for log in logging_outputs) + l2_loss_sum = sum(log.get("l2_loss", 0) for log in logging_outputs) + bce_loss_sum = sum(log.get("bce_loss", 0) for log in logging_outputs) + sample_size = max(1, sum(log.get("sample_size", 0) for log in logging_outputs)) + metrics.log_scalar( + "loss", loss_sum / sample_size, sample_size, 1, round=5 + ) + encoder_alpha_sum = sum(log.get("encoder_alpha", 0) for log in logging_outputs) + decoder_alpha_sum = sum(log.get("decoder_alpha", 0) for log in logging_outputs) + ngpu = sum(log.get("ngpu", 0) for log in logging_outputs) + + metrics.log_scalar( + "l1_loss", l1_loss_sum / sample_size, sample_size, 2, round=5 + ) + metrics.log_scalar( + "l2_loss", l2_loss_sum / sample_size, sample_size, 2, round=5 + ) + metrics.log_scalar( + "bce_loss", bce_loss_sum / sample_size, sample_size, 2, round=5 + ) + metrics.log_scalar( + "encoder_alpha", encoder_alpha_sum / sample_size, sample_size, round=5 + ) + metrics.log_scalar( + "decoder_alpha", decoder_alpha_sum / sample_size, sample_size, round=5 + ) + + if "enc_dec_attn_loss" in logging_outputs[0]: + enc_dec_attn_loss_sum = sum(log.get("enc_dec_attn_loss", 0) for log in logging_outputs) + metrics.log_scalar( + "enc_dec_attn_loss", enc_dec_attn_loss_sum / sample_size, sample_size, round=8 + ) + + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True + +class Tacotron2Loss(torch.nn.Module): + """Loss function module for Tacotron2.""" + + def __init__( + self, use_masking=True, use_weighted_masking=False, bce_pos_weight=20.0 + ): + """Initialize Tactoron2 loss module. + + Args: + use_masking (bool): Whether to apply masking + for padded part in loss calculation. + use_weighted_masking (bool): + Whether to apply weighted masking in loss calculation. 
+ bce_pos_weight (float): Weight of positive sample of stop token. + + """ + super(Tacotron2Loss, self).__init__() + assert (use_masking != use_weighted_masking) or not use_masking + self.use_masking = use_masking + self.use_weighted_masking = use_weighted_masking + + # define criterions + # reduction = "none" if self.use_weighted_masking else "sum" + reduction = "none" if self.use_weighted_masking else "mean" + self.l1_criterion = torch.nn.L1Loss(reduction=reduction) + self.mse_criterion = torch.nn.MSELoss(reduction=reduction) + self.bce_criterion = torch.nn.BCEWithLogitsLoss( + reduction=reduction, pos_weight=torch.tensor(bce_pos_weight) + ) + + # NOTE(kan-bayashi): register pre hook function for the compatibility + self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook) + + def forward(self, after_outs, before_outs, logits, ys, labels, olens): + """Calculate forward propagation. + + Args: + after_outs (Tensor): Batch of outputs after postnets (B, Lmax, odim). + before_outs (Tensor): Batch of outputs before postnets (B, Lmax, odim). + logits (Tensor): Batch of stop logits (B, Lmax). + ys (Tensor): Batch of padded target features (B, Lmax, odim). + labels (LongTensor): Batch of the sequences of stop token labels (B, Lmax). + olens (LongTensor): Batch of the lengths of each target (B,). + + Returns: + Tensor: L1 loss value. + Tensor: Mean square error loss value. + Tensor: Binary cross entropy loss value. + + """ + # make mask and apply it + if self.use_masking: + masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device) + ys = ys.masked_select(masks) + after_outs = after_outs.masked_select(masks) + before_outs = before_outs.masked_select(masks) + labels = labels.masked_select(masks[:, :, 0]) + logits = logits.masked_select(masks[:, :, 0]) + + # calculate loss + l1_loss = self.l1_criterion(after_outs, ys) + self.l1_criterion(before_outs, ys) + mse_loss = self.mse_criterion(after_outs, ys) + self.mse_criterion( + before_outs, ys + ) + bce_loss = self.bce_criterion(logits, labels) + + # make weighted mask and apply it + if self.use_weighted_masking: + masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device) + weights = masks.float() / masks.sum(dim=1, keepdim=True).float() + out_weights = weights.div(ys.size(0) * ys.size(2)) + logit_weights = weights.div(ys.size(0)) + + # apply weight + l1_loss = l1_loss.mul(out_weights).masked_select(masks).sum() + mse_loss = mse_loss.mul(out_weights).masked_select(masks).sum() + bce_loss = ( + bce_loss.mul(logit_weights.squeeze(-1)) + .masked_select(masks.squeeze(-1)) + .sum() + ) + + return l1_loss, mse_loss, bce_loss + + def _load_state_dict_pre_hook( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + """Apply pre hook fucntion before loading state dict. + + From v.0.6.1 `bce_criterion.pos_weight` param is registered as a parameter but + old models do not include it and as a result, it causes missing key error when + loading old model parameter. This function solve the issue by adding param in + state dict before loading as a pre hook function + of the `load_state_dict` method. + + """ + key = prefix + "bce_criterion.pos_weight" + if key not in state_dict: + state_dict[key] = self.bce_criterion.pos_weight + +class GuidedMultiHeadAttentionLoss(GuidedAttentionLoss): + """Guided attention loss function module for multi head attention. + Args: + sigma (float, optional): Standard deviation to control + how close attention to a diagonal. 
+ alpha (float, optional): Scaling coefficient (lambda). + reset_always (bool, optional): Whether to always reset masks. + """ + + def forward(self, att_ws, ilens, olens): + """Calculate forward propagation. + Args: + att_ws (Tensor): + Batch of multi head attention weights (B, H, T_max_out, T_max_in). + ilens (LongTensor): Batch of input lenghts (B,). + olens (LongTensor): Batch of output lenghts (B,). + Returns: + Tensor: Guided attention loss value. + """ + if self.guided_attn_masks is None: + self.guided_attn_masks = ( + self._make_guided_attention_masks(ilens, olens) + .to(att_ws.device) + .unsqueeze(1) + ) + if self.masks is None: + self.masks = self._make_masks(ilens, olens).to(att_ws.device).unsqueeze(1) + losses = self.guided_attn_masks * att_ws + loss = torch.mean(losses.masked_select(self.masks)) + if self.reset_always: + self._reset_masks() + + return self.alpha * loss + + def _make_guided_attention_masks(self, ilens, olens): + n_batches = len(ilens) + max_ilen = max(ilens) + max_olen = max(olens) + guided_attn_masks = torch.zeros((n_batches, max_olen, max_ilen), device=olens.device) + for idx, (ilen, olen) in enumerate(zip(ilens, olens)): + guided_attn_masks[idx, :olen, :ilen] = self._make_guided_attention_mask( + ilen, olen, self.sigma + ) + return guided_attn_masks + + @staticmethod + def _make_guided_attention_mask(ilen, olen, sigma): + grid_x, grid_y = torch.meshgrid(torch.arange(olen, device=olen.device), torch.arange(ilen, device=olen.device)) + grid_x, grid_y = grid_x.float(), grid_y.float() + return 1.0 - torch.exp( + -((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma**2)) + ) + + @staticmethod + def _make_masks(ilens, olens): + in_masks = make_non_pad_mask(ilens).to(ilens.device) # (B, T_in) + out_masks = make_non_pad_mask(olens).to(olens.device) # (B, T_out) + return out_masks.unsqueeze(-1) & in_masks.unsqueeze(-2) # (B, T_out, T_in) diff --git a/artst/data/__init__.py b/artst/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/artst/data/__pycache__/__init__.cpython-38.pyc b/artst/data/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc07879ce5813f6aa7c3681788dd85cf874f0f71 Binary files /dev/null and b/artst/data/__pycache__/__init__.cpython-38.pyc differ diff --git a/artst/data/__pycache__/multitask_dataset.cpython-38.pyc b/artst/data/__pycache__/multitask_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3242d8e5d49d6a3b45a736a7a11310d5233dde7 Binary files /dev/null and b/artst/data/__pycache__/multitask_dataset.cpython-38.pyc differ diff --git a/artst/data/__pycache__/speech_dataset.cpython-38.pyc b/artst/data/__pycache__/speech_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45b621844ac0158846187094405aa17328ddcaa5 Binary files /dev/null and b/artst/data/__pycache__/speech_dataset.cpython-38.pyc differ diff --git a/artst/data/__pycache__/speech_to_class_dataset.cpython-38.pyc b/artst/data/__pycache__/speech_to_class_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a560b1812de6f888312a5f4f2e38611eb92959eb Binary files /dev/null and b/artst/data/__pycache__/speech_to_class_dataset.cpython-38.pyc differ diff --git a/artst/data/__pycache__/speech_to_speech_dataset.cpython-38.pyc b/artst/data/__pycache__/speech_to_speech_dataset.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..8fc71a8c923ec3616fbe44c4691048ea76c92d18 Binary files /dev/null and b/artst/data/__pycache__/speech_to_speech_dataset.cpython-38.pyc differ diff --git a/artst/data/__pycache__/speech_to_text_dataset.cpython-38.pyc b/artst/data/__pycache__/speech_to_text_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d71083a27b07c2d6a9cfdba27930db3bff0fbbd Binary files /dev/null and b/artst/data/__pycache__/speech_to_text_dataset.cpython-38.pyc differ diff --git a/artst/data/__pycache__/text_dataset.cpython-38.pyc b/artst/data/__pycache__/text_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c3582618e4f12d7a810139db06491ae2ab1702f Binary files /dev/null and b/artst/data/__pycache__/text_dataset.cpython-38.pyc differ diff --git a/artst/data/__pycache__/text_to_speech_dataset.cpython-38.pyc b/artst/data/__pycache__/text_to_speech_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4d712b85ec1d76f2e2cb1865c375b8b733934df Binary files /dev/null and b/artst/data/__pycache__/text_to_speech_dataset.cpython-38.pyc differ diff --git a/artst/data/multitask_dataset.py b/artst/data/multitask_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..37b6272315e5f53c093c75c9567e64efdac93bdc --- /dev/null +++ b/artst/data/multitask_dataset.py @@ -0,0 +1,263 @@ +# -------------------------------------------------------- +# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621) +# Github source: https://github.com/mbzuai-nlp/ArTST +# Based on speecht5, fairseq and espnet code bases +# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet +# -------------------------------------------------------- + +import bisect + +import logging +import numpy as np +from torch.utils.data.dataloader import default_collate +from fairseq.data import data_utils + +from fairseq.data.fairseq_dataset import FairseqDataset + +logger = logging.getLogger(__name__) + +class MultitaskDataset(FairseqDataset): + @staticmethod + def cumsum(sequence): + r, s = [], 0 + for e in sequence: + curr_len = len(e) + r.append(curr_len + s) + s += curr_len + return r + + def __init__(self, datasets, sample_ratios=1, batch_ratio=None): + super(MultitaskDataset, self).__init__() + assert len(datasets) > 0, "datasets should not be an empty iterable" + self.datasets = list(datasets) + if isinstance(sample_ratios, int): + sample_ratios = [sample_ratios] * len(self.datasets) + if batch_ratio is not None: + logger.info('batch ratio is ' + str(batch_ratio)) + self.batch_ratio = batch_ratio + else: + self.batch_ratio = None + else: + logger.info('set sample ratio to ' + str(sample_ratios)) + if batch_ratio is not None: + logger.info('batch ratio is ' + str(batch_ratio)) + self.batch_ratio = batch_ratio + else: + self.batch_ratio = None + self.sample_ratios = sample_ratios + self._ordered_indices = None + self._update_size() + + def __len__(self): + return self.cumulative_sizes[-1] + + def __getitem__(self, idx): + dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx) + sample = self.datasets[dataset_idx][sample_idx] + if isinstance(sample, dict): + sample["dataset_idx"] = dataset_idx + else: + sample = sample + (dataset_idx,) + return sample + + def _update_size(self): + self.cumulative_sizes = self.cumsum(self.datasets) + self.real_sizes = [len(d) for d in self.datasets] + + 
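# For example, with three sub-datasets of lengths [100, 50, 25], cumsum()
+ # gives cumulative_sizes = [100, 150, 175]; _get_dataset_and_sample_index()
+ # below then resolves a global index such as 130 with bisect_right to
+ # dataset_idx=1 and sample_idx=130-100=30 (taken modulo the real sub-dataset size). +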
def _get_dataset_and_sample_index(self, idx: int): + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + sample_idx = sample_idx % self.real_sizes[dataset_idx] + return dataset_idx, sample_idx + + def collater(self, samples, **extra_args): + # For now only supports datasets with same underlying collater implementations + if samples is not None and len(samples) > 0: + if isinstance(samples[0], dict): + dataset_idx = samples[0]["dataset_idx"] + else: + dataset_idx = samples[0][-1] + samples = [sample[:-1] for sample in samples] + else: + dataset_idx = 0 + + if hasattr(self.datasets[dataset_idx], "collater"): + return self.datasets[dataset_idx].collater(samples, **extra_args) + else: + return default_collate(samples, **extra_args) + + def size(self, idx: int): + """ + Return an example's size as a float or tuple. + """ + dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx) + return self.datasets[dataset_idx].size(sample_idx) + + def num_tokens(self, index: int): + return np.max(self.size(index)) + + def attr(self, attr: str, index: int): + dataset_idx = bisect.bisect_right(self.cumulative_sizes, index) + return getattr(self.datasets[dataset_idx], attr, None) + + @property + def sizes(self): + _dataset_sizes = [] + for ds in self.datasets: + if isinstance(ds.sizes, np.ndarray): + _dataset_sizes.append(ds.sizes) + else: + # Only support underlying dataset with single size array. + assert isinstance(ds.sizes, list) + _dataset_sizes.append(ds.sizes[0]) + return np.concatenate(_dataset_sizes) + + @property + def supports_prefetch(self): + return all(d.supports_prefetch for d in self.datasets) + + def ordered_indices(self): + # ordered_indices = [] + # for i, dataset in enumerate(self.datasets): + # indice = dataset.ordered_indices() + # ordered_indices.append(indice) + if self._ordered_indices is None: + # Call the underlying dataset's ordered_indices() here, so that we + # get the same random ordering as we would have from using the + # underlying sub-datasets directly. 
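+ # The per-dataset orderings gathered here are only cached in
+ # self._ordered_indices for later use by batch_by_size() and
+ # filter_indices_by_size(); the value returned to the caller is simply
+ # np.arange(len(self)).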
+ self._ordered_indices = [ + dataset.ordered_indices() + for dataset in self.datasets + ] + return np.arange(len(self)) + + def prefetch(self, indices): + frm = 0 + for to, ds in zip(self.cumulative_sizes, self.datasets): + real_size = len(ds) + if getattr(ds, "supports_prefetch", False): + ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to]) + frm = to + + def batch_by_size( + self, + indices, + max_tokens=None, + max_sentences=None, + required_batch_size_multiple=1, + ): + if not hasattr(self, "max_tokens"): + self.max_tokens = max_tokens + if not hasattr(self, "max_sentences"): + self.max_sentences = max_sentences + if not hasattr(self, "required_batch_size_multiple"): + self.required_batch_size_multiple = required_batch_size_multiple + batch_samplers = [] + for i, dataset in enumerate(self.datasets): + batch_sampler = dataset.batch_by_size( + self._ordered_indices[i], + max_tokens=max_tokens if self.batch_ratio is None else max_tokens * self.batch_ratio[i], + max_sentences=max_sentences, + required_batch_size_multiple=required_batch_size_multiple, + ) + if i > 0: + for batch in batch_sampler: + batch += self.cumulative_sizes[i - 1] + if self.sample_ratios[i] != 1.0: + batch_sampler = np.array(batch_sampler) + batch_sampler = np.random.choice(batch_sampler, int(len(batch_sampler) * self.sample_ratios[i])) + batch_sampler = list(batch_sampler) + logger.info('Adjust batch by ratio ' + str(self.sample_ratios[i]) + ' and the number of batch is ' + str(int(len(batch_sampler))) + ' for dataset ' + str(i)) + batch_samplers.extend(batch_sampler) + return batch_samplers + + def filter_indices_by_size(self, indices, max_positions): + """ + Filter each sub-dataset independently, then update the round robin to work + on the filtered sub-datasets. + """ + if not hasattr(self, "max_positions"): + self.max_positions = max_positions + ignored_some = False + for i in range(len(self.datasets)): + # ignored = [] + self._ordered_indices[i], ignored = self.datasets[i].filter_indices_by_size( + self._ordered_indices[i], self.max_positions[i] + ) + if len(ignored) > 0: + ignored_some = True + logger.warning( + f"{len(ignored)} samples from {i} have invalid sizes and will be skipped, " + f"max_positions={self.max_positions[i]}, first few sample ids={ignored[:10]}" + ) + + logger.info('update dataset size') + self._update_size() + + # Since we are modifying in place the _ordered_indices, + # it's not possible anymore to return valid ignored indices. + # Hopefully the extra debug information print above should be enough to debug. + # Ideally we would receive ignore_invalid_inputs so that we could have + # a proper error message. 
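+ # Note that each cached ordering is filtered in place against its own
+ # max_positions[i]; callers therefore only receive a dummy ignored list
+ # ([0] when anything was dropped) rather than the real ignored indices.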
+ return (np.arange(len(self)), [0] if ignored_some else []) + + @property + def can_reuse_epoch_itr_across_epochs(self): + return all(d.can_reuse_epoch_itr_across_epochs for d in self.datasets) + + def set_epoch(self, epoch): + super().set_epoch(epoch) + for ds in self.datasets: + if hasattr(ds, "set_epoch"): + ds.set_epoch(epoch) + + def shuffle_batches(self, batches, seed): + logger.info("shuffle batches") + new_batches_fromlist = [] + new_batches_notlist = [] + new_batches = [] + with data_utils.numpy_seed(seed): + np.random.shuffle(batches) + for batch in batches: + if isinstance(batch, list): + # np.random.shuffle(batch) + new_batches_fromlist.append(batch) + else: + new_batches_notlist.append(batch) + logger.info("Get " + str(len(new_batches_fromlist)) + " chunk from speech sides") + logger.info("Get " + str(sum([len(batch_list) for batch_list in new_batches_fromlist])) + " batches from speech sides") + logger.info("Get " + str(len(new_batches_notlist)) + " batches from text sides") + if len(new_batches_fromlist) == 0: + return new_batches_notlist + st_ratio = int(len(new_batches_notlist) / len(new_batches_fromlist)) + logger.info("Get st_ratio " + str(st_ratio)) + last_idx = 0 + for i in range(len(new_batches_fromlist)): + if i == len(new_batches_fromlist) - 1: + new_batches_fromlist[i].extend(new_batches_notlist[last_idx:]) + else: + new_batches_fromlist[i].extend(new_batches_notlist[last_idx : last_idx + st_ratio]) + np.random.shuffle(new_batches_fromlist[i]) + new_batches.extend(new_batches_fromlist[i]) + last_idx = last_idx + st_ratio + logger.info("Finish shuffle") + return new_batches + + def reset_batch_sampler(self): + logger.info("reset batch sampler") + self._ordered_indices = [ + self.datasets[i].ordered_indices() + for i in range(len(self.datasets)) + ] + self.filter_indices_by_size(None, None) + + batch_samplers = self.batch_by_size( + None, + self.max_tokens, + self.max_sentences, + self.required_batch_size_multiple + ) + return batch_samplers diff --git a/artst/data/speech_dataset.py b/artst/data/speech_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ae03c3859c763518fcc74b87b7de5728278937bb --- /dev/null +++ b/artst/data/speech_dataset.py @@ -0,0 +1,475 @@ +# -------------------------------------------------------- +# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621) +# Github source: https://github.com/mbzuai-nlp/ArTST + +# Based on speecht5, fairseq and espnet code bases +# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet +# -------------------------------------------------------- + +import itertools +import logging +import os +import sys +from typing import Any, List, Optional, Union + +import numpy as np + +import torch +import torch.nn.functional as F +import librosa +from fairseq.data.audio.speech_to_text_dataset import get_features_or_waveform +from fairseq.data import data_utils +from fairseq.data.fairseq_dataset import FairseqDataset + +logger = logging.getLogger(__name__) + +def _collate_frames( + frames: List[torch.Tensor], is_audio_input: bool = False +): + """ + Convert a list of 2D frames into a padded 3D tensor + Args: + frames (list): list of 2D frames of size L[i]*f_dim. 
Where L[i] is + length of i-th frame and f_dim is static dimension of features + Returns: + 3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i] + """ + max_len = max(frame.size(0) for frame in frames) + if is_audio_input: + out = frames[0].new_zeros((len(frames), max_len)) + else: + out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1))) + for i, v in enumerate(frames): + out[i, : v.size(0)] = v + return out + +def add_first_frame_and_remove_last_frame(ys): + ys_in = torch.cat( + [ys.new_zeros((ys.shape[0], 1, ys.shape[2])), ys[:, :-1]], dim=1 + ) + return ys_in + +def load_audio(manifest_path, max_keep, min_keep): + n_long, n_short = 0, 0 + names, inds, sizes, spk_embeds = [], [], [], [] + with open(manifest_path) as f: + root = f.readline().strip() + for ind, line in enumerate(f): + items = line.strip().split("\t") + assert len(items) == 3, line + sz = int(items[1]) + if min_keep is not None and sz < min_keep: + n_short += 1 + elif max_keep is not None and sz > max_keep: + n_long += 1 + else: + names.append(items[0]) + spk_embeds.append(items[2]) + inds.append(ind) + sizes.append(sz) + tot = ind + 1 + logger.info( + ( + f"max_keep={max_keep}, min_keep={min_keep}, " + f"loaded {len(names)}, skipped {n_short} short and {n_long} long, " + f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}" + ) + ) + return root, names, inds, tot, sizes, spk_embeds + + +def load_label(label_path, inds, tot): + with open(label_path) as f: + labels = [line.rstrip() for line in f] + assert ( + len(labels) == tot + ), f"number of labels does not match ({len(labels)} != {tot})" + labels = [labels[i] for i in inds] + return labels + + +def load_label_offset(label_path, inds, tot): + with open(label_path) as f: + code_lengths = [len(line.encode("utf-8")) for line in f] + assert ( + len(code_lengths) == tot + ), f"number of labels does not match ({len(code_lengths)} != {tot})" + offsets = list(itertools.accumulate([0] + code_lengths)) + offsets = [(offsets[i], offsets[i + 1]) for i in inds] + return offsets + + +def verify_label_lengths( + audio_sizes, + audio_rate, + label_path, + label_rate, + inds, + tot, + tol=0.1, # tolerance in seconds +): + if label_rate < 0: + logger.info(f"{label_path} is sequence label. skipped") + return + + with open(label_path) as f: + lengths = [len(line.rstrip().split()) for line in f] + assert len(lengths) == tot + lengths = [lengths[i] for i in inds] + num_invalid = 0 + for i, ind in enumerate(inds): + dur_from_audio = audio_sizes[i] / audio_rate + dur_from_label = lengths[i] / label_rate + if abs(dur_from_audio - dur_from_label) > tol: + logger.warning( + ( + f"audio and label duration differ too much " + f"(|{dur_from_audio} - {dur_from_label}| > {tol}) " + f"in line {ind+1} of {label_path}. Check if `label_rate` " + f"is correctly set (currently {label_rate}). " + f"num. of samples = {audio_sizes[i]}; " + f"label length = {lengths[i]}" + ) + ) + num_invalid += 1 + if num_invalid > 0: + logger.warning( + f"total {num_invalid} (audio, label) pairs with mismatched lengths" + ) + + +def logmelfilterbank( + audio, + sampling_rate, + fft_size=1024, + hop_size=256, + win_length=None, + window="hann", + num_mels=80, + fmin=80, + fmax=7600, + eps=1e-10, +): + """Compute log-Mel filterbank feature. + (https://github.com/kan-bayashi/ParallelWaveGAN/blob/master/parallel_wavegan/bin/preprocess.py) + + Args: + audio (ndarray): Audio signal (T,). + sampling_rate (int): Sampling rate. + fft_size (int): FFT size. + hop_size (int): Hop size. 
+ win_length (int): Window length. If set to None, it will be the same as fft_size. + window (str): Window function type. + num_mels (int): Number of mel basis. + fmin (int): Minimum frequency in mel basis calculation. + fmax (int): Maximum frequency in mel basis calculation. + eps (float): Epsilon value to avoid inf in log calculation. + + Returns: + ndarray: Log Mel filterbank feature (#frames, num_mels). + + """ + # get amplitude spectrogram + x_stft = librosa.stft(audio, n_fft=fft_size, hop_length=hop_size, + win_length=win_length, window=window, pad_mode="reflect") + spc = np.abs(x_stft).T # (#frames, #bins) + + # get mel basis + fmin = 0 if fmin is None else fmin + fmax = sampling_rate / 2 if fmax is None else fmax + mel_basis = librosa.filters.mel(sr=sampling_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax) + + return np.log10(np.maximum(eps, np.dot(spc, mel_basis.T))) + + +class SpeechPretrainDataset(FairseqDataset): + def __init__( + self, + manifest_path: str, + sample_rate: float, + label_paths: List[str], + label_rates: Union[List[float], float], # -1 for sequence labels + pad_list: List[str], + eos_list: List[str], + label_processors: Optional[List[Any]] = None, + max_keep_sample_size: Optional[int] = None, + min_keep_sample_size: Optional[int] = None, + max_sample_size: Optional[int] = None, + shuffle: bool = True, + pad_audio: bool = False, + normalize: bool = False, + store_labels: bool = True, + random_crop: bool = False, + single_target: bool = False, + reduction_factor: int = 1, + ): + self.audio_root, self.audio_names, inds, tot, self.sizes, self.spk_embeds = load_audio( + manifest_path, max_keep_sample_size, min_keep_sample_size + ) + self.sample_rate = sample_rate + self.shuffle = shuffle + self.random_crop = random_crop + + self.num_labels = len(label_paths) + self.pad_list = pad_list + self.eos_list = eos_list + self.label_processors = label_processors + self.single_target = single_target + self.label_rates = ( + [label_rates for _ in range(len(label_paths))] + if isinstance(label_rates, float) + else label_rates + ) + self.store_labels = store_labels + if store_labels: + self.label_list = [load_label(p, inds, tot) for p in label_paths] + else: + self.label_paths = label_paths + self.label_offsets_list = [ + load_label_offset(p, inds, tot) for p in label_paths + ] + assert label_processors is None or len(label_processors) == self.num_labels + for label_path, label_rate in zip(label_paths, self.label_rates): + verify_label_lengths( + self.sizes, sample_rate, label_path, label_rate, inds, tot + ) + + self.max_sample_size = ( + max_sample_size if max_sample_size is not None else sys.maxsize + ) + self.pad_audio = pad_audio + self.normalize = normalize + self.reduction_factor = reduction_factor + logger.info( + f"pad_audio={pad_audio}, random_crop={random_crop}, reduction_factor={reduction_factor}, " + f"normalize={normalize}, max_sample_size={self.max_sample_size}" + ) + + def get_audio(self, index): + import soundfile as sf + + wav_path = os.path.join(self.audio_root, self.audio_names[index]) + wav, cur_sample_rate = sf.read(wav_path) + wav = torch.from_numpy(wav).float() + fbank = logmelfilterbank( + wav.view(-1).cpu().numpy(), 16000 + ) + fbank = torch.from_numpy(fbank).float() + wav = self.postprocess(wav, cur_sample_rate) + return wav, fbank + + def get_label(self, index, label_idx): + if self.store_labels: + label = self.label_list[label_idx][index] + else: + with open(self.label_paths[label_idx]) as f: + offset_s, offset_e = 
self.label_offsets_list[label_idx][index] + f.seek(offset_s) + label = f.read(offset_e - offset_s) + + if self.label_processors is not None: + label = self.label_processors[label_idx](label) + return label + + def get_labels(self, index): + return [self.get_label(index, i) for i in range(self.num_labels)] + + def __getitem__(self, index): + wav, fbank = self.get_audio(index) + labels = self.get_labels(index) + spkembs = get_features_or_waveform( + os.path.join(self.audio_root, self.spk_embeds[index]) + ) + spkembs = torch.from_numpy(spkembs).float() + return {"id": index, "source": wav, "target": fbank, "label_list": labels, 'spkembs': spkembs} + + def __len__(self): + return len(self.sizes) + + def crop_to_max_size(self, wav, target_size): + size = len(wav) + diff = size - target_size + if diff <= 0: + return wav, 0 + + start, end = 0, target_size + if self.random_crop: + start = np.random.randint(0, diff + 1) + end = size - diff + start + return wav[start:end], start + + def collater(self, samples): + # target = max(sizes) -> random_crop not used + # target = max_sample_size -> random_crop used for long + samples = [s for s in samples if s["source"] is not None] + if len(samples) == 0: + return {} + + audios = [s["source"] for s in samples] + audio_sizes = [len(s) for s in audios] + + fbanks = [s["target"] for s in samples] + fbank_sizes = [len(s) for s in fbanks] + + if self.pad_audio: + audio_size = min(max(audio_sizes), self.max_sample_size) + else: + audio_size = min(min(audio_sizes), self.max_sample_size) + collated_audios, padding_mask, audio_starts = self.collater_audio( + audios, audio_size + ) + + collated_fbanks = [] + collated_audios_size = [] + for i in range(len(fbanks)): + fbank_start = int(audio_starts[i] / (audio_sizes[i] / fbank_sizes[i])) + fbank_size = int(audio_size / (audio_sizes[i] / fbank_sizes[i])) + fbank_end = min(fbank_start + fbank_size, fbank_sizes[i]) + collated_fbanks.append(fbanks[i][fbank_start : fbank_end]) + collated_audios_size.append(audio_size) + collated_fbanks_size = [len(s) for s in collated_fbanks] + collated_fbanks = _collate_frames(collated_fbanks) + collated_fbanks_size = torch.tensor(collated_fbanks_size, dtype=torch.long) + + # thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim) + if self.reduction_factor > 1: + collated_fbanks_in = collated_fbanks[:, self.reduction_factor - 1 :: self.reduction_factor] + collated_fbanks_size_in = collated_fbanks_size.new([torch.div(olen, self.reduction_factor, rounding_mode='floor') for olen in collated_fbanks_size]) + else: + collated_fbanks_in, collated_fbanks_size_in = collated_fbanks, collated_fbanks_size + + prev_output_tokens = torch.cat( + [collated_fbanks_in.new_zeros((collated_fbanks_in.shape[0], 1, collated_fbanks_in.shape[2])), collated_fbanks_in[:, :-1]], dim=1 + ) + + # make labels for stop prediction + labels = collated_fbanks.new_zeros(collated_fbanks.size(0), collated_fbanks.size(1)) + for i, l in enumerate(fbank_sizes): + labels[i, l - 1 :] = 1.0 + + spkembs = _collate_frames([s["spkembs"] for s in samples], is_audio_input=True) + + targets_by_label = [ + [s["label_list"][i] for s in samples] for i in range(self.num_labels) + ] + targets_list, lengths_list, ntokens_list = self.collater_label( + targets_by_label, audio_size, audio_starts + ) + + net_input = { + "source": collated_audios, + "padding_mask": padding_mask, + "prev_output_tokens": prev_output_tokens, + "spkembs": spkembs, + "tgt_lengths": collated_fbanks_size_in, + } + + batch = { + "id": 
torch.LongTensor([s["id"] for s in samples]), + "net_input": net_input, + "labels": labels, + "dec_target": collated_fbanks, + "dec_target_lengths": collated_fbanks_size, + "src_lengths": collated_audios_size, + "task_name": 'speech_pretrain', + } + + if self.single_target: + batch["target_lengths"] = lengths_list[0] + batch["ntokens"] = ntokens_list[0] + batch["target"] = targets_list[0] + else: + batch["target_lengths_list"] = lengths_list + batch["ntokens_list"] = ntokens_list + batch["target_list"] = targets_list + return batch + + def collater_audio(self, audios, audio_size): + collated_audios = audios[0].new_zeros(len(audios), audio_size) + padding_mask = ( + torch.BoolTensor(collated_audios.shape).fill_(False) + # if self.pad_audio else None + ) + audio_starts = [0 for _ in audios] + for i, audio in enumerate(audios): + diff = len(audio) - audio_size + if diff == 0: + collated_audios[i] = audio + elif diff < 0: + assert self.pad_audio + collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)]) + padding_mask[i, diff:] = True + else: + collated_audios[i], audio_starts[i] = self.crop_to_max_size( + audio, audio_size + ) + return collated_audios, padding_mask, audio_starts + + def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad): + assert label_rate > 0 + s2f = label_rate / self.sample_rate + frm_starts = [int(round(s * s2f)) for s in audio_starts] + frm_size = int(round(audio_size * s2f)) + if not self.pad_audio: + rem_size = [len(t) - s for t, s in zip(targets, frm_starts)] + frm_size = min(frm_size, *rem_size) + targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)] + logger.debug(f"audio_starts={audio_starts}") + logger.debug(f"frame_starts={frm_starts}") + logger.debug(f"frame_size={frm_size}") + + lengths = torch.LongTensor([len(t) for t in targets]) + ntokens = lengths.sum().item() + targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False) + return targets, lengths, ntokens + + def collater_seq_label(self, targets, pad): + lengths = torch.LongTensor([len(t) for t in targets]) + ntokens = lengths.sum().item() + targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False) + return targets, lengths, ntokens + + def collater_label(self, targets_by_label, audio_size, audio_starts): + targets_list, lengths_list, ntokens_list = [], [], [] + itr = zip(targets_by_label, self.label_rates, self.pad_list) + for targets, label_rate, pad in itr: + if label_rate == -1.0: + targets, lengths, ntokens = self.collater_seq_label(targets, pad) + else: + targets, lengths, ntokens = self.collater_frm_label( + targets, audio_size, audio_starts, label_rate, pad + ) + targets_list.append(targets) + lengths_list.append(lengths) + ntokens_list.append(ntokens) + return targets_list, lengths_list, ntokens_list + + def num_tokens(self, index): + return self.size(index) + + def size(self, index): + if self.pad_audio: + return self.sizes[index] + return min(self.sizes[index], self.max_sample_size) + + def ordered_indices(self): + if self.shuffle: + order = [np.random.permutation(len(self))] + else: + order = [np.arange(len(self))] + + order.append(self.sizes) + return np.lexsort(order)[::-1] + + def postprocess(self, wav, cur_sample_rate): + if wav.dim() == 2: + wav = wav.mean(-1) + assert wav.dim() == 1, wav.dim() + + if cur_sample_rate != self.sample_rate: + raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}") + + if self.normalize: + with torch.no_grad(): + wav = F.layer_norm(wav, wav.shape) + return wav diff --git 
a/artst/data/speech_to_class_dataset.py b/artst/data/speech_to_class_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..2c3c8e964ad05bc5d7fd85933f6f523e5dba5aea --- /dev/null +++ b/artst/data/speech_to_class_dataset.py @@ -0,0 +1,260 @@ +# -------------------------------------------------------- +# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621) +# Github source: https://github.com/mbzuai-nlp/ArTST +# Based on speecht5, fairseq and espnet code bases +# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet +# -------------------------------------------------------- + +import logging +import os +from typing import Any, List, Optional + +import numpy as np + +import torch +import torch.nn.functional as F +from fairseq.data import data_utils, Dictionary +from fairseq.data.fairseq_dataset import FairseqDataset + +logger = logging.getLogger(__name__) + + +def load_audio(manifest_path, max_keep, min_keep): + """manifest tsv: wav_path, wav_nframe, wav_class + + Args + manifest_path: str + max_keep: int + min_keep: int + + Return + root, names, inds, tot, sizes, classes + """ + n_long, n_short = 0, 0 + names, inds, sizes, classes = [], [], [], [] + with open(manifest_path) as f: + root = f.readline().strip() + for ind, line in enumerate(f): + items = line.strip().split("\t") + assert len(items) >= 2, line + sz = int(items[1]) + if min_keep is not None and sz < min_keep: + n_short += 1 + elif max_keep is not None and sz > max_keep: + n_long += 1 + else: + names.append(items[0]) + if len(items) > 2: + classes.append(items[2]) + inds.append(ind) + sizes.append(sz) + tot = ind + 1 + logger.info( + ( + f"max_keep={max_keep}, min_keep={min_keep}, " + f"loaded {len(names)}, skipped {n_short} short and {n_long} long, " + f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}" + ) + ) + if len(classes) == 0: + logger.warn("no classes loaded only if inference") + return root, names, inds, tot, sizes, classes + + +def sample_from_feature(x: np.ndarray, max_segment_length: int = 300): + """Load a segment within 300-400/51200-76800 frames or the corresponding samples from a utterance. + + Args: + x (np.ndarray): feature or waveform (frames[, features]), e.g., log mel filter bank or waveform + max_segment_length (int, optional): maximum segment length. Defaults to 400. 
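+ Example (illustrative): for a waveform of 80000 samples and
+ max_segment_length=51200, a start index is drawn uniformly from
+ [0, 28800) and a 51200-sample slice is returned; inputs no longer
+ than max_segment_length are returned unchanged.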
+ + Returns: + np.ndarray: segmented features + """ + if len(x) <= max_segment_length: + return x + start = np.random.randint(0, x.shape[0] - max_segment_length) + return x[start: start + max_segment_length] + + +class SpeechToClassDataset(FairseqDataset): + def __init__( + self, + manifest_path: str, + sample_rate: float, + label_processors: Optional[List[Any]] = None, + max_keep_sample_size: Optional[int] = None, + min_keep_sample_size: Optional[int] = None, + shuffle: bool = True, + normalize: bool = False, + tgt_dict: Optional[Dictionary] = None, + max_length: Optional[int] = None + ): + self.audio_root, self.audio_names, inds, tot, self.wav_sizes, self.wav_classes = load_audio( + manifest_path, max_keep_sample_size, min_keep_sample_size + ) + self.sample_rate = sample_rate + self.shuffle = shuffle + + self.label_processors = label_processors + + self.normalize = normalize + self.tgt_dict = tgt_dict + self.max_length = max_length + logger.info( + f"max_length={max_length}, normalize={normalize}" + ) + + def get_audio(self, index): + import soundfile as sf + + wav_path = os.path.join(self.audio_root, self.audio_names[index]) + wav, cur_sample_rate = sf.read(wav_path) + if self.max_length is not None: + wav = sample_from_feature(wav, self.max_length) + wav = torch.from_numpy(wav).float() + wav = self.postprocess(wav, cur_sample_rate) + return wav + + def get_label(self, index): + label = self.wav_classes[index] + + if self.label_processors is not None: + label = self.label_processors(label) + return label + + def __getitem__(self, index): + wav = self.get_audio(index) + label = None + if len(self.wav_classes) == len(self.audio_names): + label = self.get_label(index) + return {"id": index, "source": wav, "label": label} + + def __len__(self): + return len(self.wav_sizes) + + def collater(self, samples): + samples = [s for s in samples if s["source"] is not None] + if len(samples) == 0: + return {} + + audios = [s["source"] for s in samples] + audio_sizes = [len(s) for s in audios] + + audio_size = max(audio_sizes) + collated_audios, padding_mask = self.collater_audio( + audios, audio_size + ) + + decoder_label = None + decoder_target = None + decoder_target_lengths = None + if samples[0]["label"] is not None: + targets_by_label = [ + [s["label"] for s in samples] + ] + targets_list, lengths_list, ntokens_list = self.collater_label(targets_by_label) + + decoder_label = [ + (targets_list[0][i, :lengths_list[0][i]]).long() + for i in range(targets_list[0].size(0)) + ] + + decoder_target = data_utils.collate_tokens( + decoder_label, + self.tgt_dict.pad(), + self.tgt_dict.eos(), + left_pad=False, + move_eos_to_beginning=False, + ) + decoder_target_lengths = torch.tensor( + [x.size(0) for x in decoder_label], dtype=torch.long + ) + prev_output_tokens = data_utils.collate_tokens( + [torch.LongTensor([-1]) for _ in samples], + self.tgt_dict.pad(), + self.tgt_dict.eos(), + left_pad=False, + move_eos_to_beginning=True, + ) + + net_input = { + "source": collated_audios, + "padding_mask": padding_mask, + "prev_output_tokens": prev_output_tokens, + "task_name": "s2c", + } + batch = { + "id": torch.LongTensor([s["id"] for s in samples]), + "net_input": net_input, + "target": decoder_target, + "target_lengths": decoder_target_lengths, + "task_name": "s2c", + "ntokens": len(samples), + } + + return batch + + def collater_audio(self, audios, audio_size): + collated_audios = audios[0].new_zeros(len(audios), audio_size) + padding_mask = ( + torch.BoolTensor(collated_audios.shape).fill_(False) + ) + for i, 
audio in enumerate(audios): + diff = len(audio) - audio_size + if diff == 0: + collated_audios[i] = audio + elif diff < 0: + collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)]) + padding_mask[i, diff:] = True + else: + raise Exception("Diff should not be larger than 0") + return collated_audios, padding_mask + + def collater_seq_label(self, targets, pad): + lengths = torch.LongTensor([len(t) for t in targets]) + ntokens = lengths.sum().item() + targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False) + return targets, lengths, ntokens + + def collater_label(self, targets_by_label): + targets_list, lengths_list, ntokens_list = [], [], [] + itr = zip(targets_by_label, [self.tgt_dict.pad()]) + for targets, pad in itr: + targets, lengths, ntokens = self.collater_seq_label(targets, pad) + targets_list.append(targets) + lengths_list.append(lengths) + ntokens_list.append(ntokens) + return targets_list, lengths_list, ntokens_list + + def num_tokens(self, index): + return self.size(index) + + def size(self, index): + return self.wav_sizes[index] + + @property + def sizes(self): + return np.array(self.wav_sizes) + + def ordered_indices(self): + if self.shuffle: + order = [np.random.permutation(len(self))] + else: + order = [np.arange(len(self))] + + order.append(self.wav_sizes) + return np.lexsort(order)[::-1] + + def postprocess(self, wav, cur_sample_rate): + if wav.dim() == 2: + wav = wav.mean(-1) + assert wav.dim() == 1, wav.dim() + + if cur_sample_rate != self.sample_rate: + raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}") + + if self.normalize: + with torch.no_grad(): + wav = F.layer_norm(wav, wav.shape) + return wav diff --git a/artst/data/speech_to_speech_dataset.py b/artst/data/speech_to_speech_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..2af87106a7d24a718bf52480a900ad8ce566001f --- /dev/null +++ b/artst/data/speech_to_speech_dataset.py @@ -0,0 +1,280 @@ +# -------------------------------------------------------- +# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621) +# Github source: https://github.com/mbzuai-nlp/ArTST +# Based on speecht5, fairseq and espnet code bases +# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet +# -------------------------------------------------------- + +import logging +import os +from typing import Any, List, Optional + +import librosa +import numpy as np +import torch +import torch.nn.functional as F +from fairseq.data.fairseq_dataset import FairseqDataset + +logger = logging.getLogger(__name__) + +def _collate_frames( + frames: List[torch.Tensor], is_audio_input: bool = False +): + """ + Convert a list of 2D frames into a padded 3D tensor + Args: + frames (list): list of 2D frames of size L[i]*f_dim. 
Where L[i] is + length of i-th frame and f_dim is static dimension of features + Returns: + 3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i] + """ + max_len = max(frame.size(0) for frame in frames) + if is_audio_input: + out = frames[0].new_zeros((len(frames), max_len)) + else: + out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1))) + for i, v in enumerate(frames): + out[i, : v.size(0)] = v + return out + +def load_audio(manifest_path, max_keep, min_keep): + """manifest tsv: src_wav, src_nframe, tgt_wav, tgt_nframe, tgt_spkemb""" + n_long, n_short = 0, 0 + src_names, tgt_names, inds, sizes, tgt_sizes, spk_embeds = [], [], [], [], [], [] + with open(manifest_path) as f: + root = f.readline().strip() + for ind, line in enumerate(f): + items = line.strip().split("\t") + assert len(items) >= 2, line + sz = int(items[1]) + if min_keep is not None and sz < min_keep: + n_short += 1 + elif max_keep is not None and sz > max_keep: + n_long += 1 + else: + src_names.append(items[0]) + tgt_names.append(items[2]) + tgt_sizes.append(items[3]) + spk_embeds.append(items[4]) + inds.append(ind) + sizes.append(sz) + tot = ind + 1 + logger.info( + ( + f"max_keep={max_keep}, min_keep={min_keep}, " + f"loaded {len(src_names)}, skipped {n_short} short and {n_long} long, " + f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}" + ) + ) + return root, src_names, inds, tot, sizes, tgt_names, tgt_sizes, spk_embeds + + +def logmelfilterbank( + audio, + sampling_rate, + fft_size=1024, + hop_size=256, + win_length=None, + window="hann", + num_mels=80, + fmin=80, + fmax=7600, + eps=1e-10, +): + """Compute log-Mel filterbank feature. + (https://github.com/kan-bayashi/ParallelWaveGAN/blob/master/parallel_wavegan/bin/preprocess.py) + + Args: + audio (ndarray): Audio signal (T,). + sampling_rate (int): Sampling rate. + fft_size (int): FFT size. + hop_size (int): Hop size. + win_length (int): Window length. If set to None, it will be the same as fft_size. + window (str): Window function type. + num_mels (int): Number of mel basis. + fmin (int): Minimum frequency in mel basis calculation. + fmax (int): Maximum frequency in mel basis calculation. + eps (float): Epsilon value to avoid inf in log calculation. + + Returns: + ndarray: Log Mel filterbank feature (#frames, num_mels). 
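+
+ Note: with the centered STFT used below (librosa's default), the number
+ of frames is roughly 1 + len(audio) // hop_size, i.e. about 62.5 frames
+ per second of 16 kHz audio at the default hop_size of 256.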
+ + """ + # get amplitude spectrogram + x_stft = librosa.stft(audio, n_fft=fft_size, hop_length=hop_size, + win_length=win_length, window=window, pad_mode="reflect") + spc = np.abs(x_stft).T # (#frames, #bins) + + # get mel basis + fmin = 0 if fmin is None else fmin + fmax = sampling_rate / 2 if fmax is None else fmax + mel_basis = librosa.filters.mel(sr=sampling_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax) + + return np.log10(np.maximum(eps, np.dot(spc, mel_basis.T))) + + +class SpeechToSpeechDataset(FairseqDataset): + def __init__( + self, + manifest_path: str, + sample_rate: float, + max_keep_sample_size: Optional[int] = None, + min_keep_sample_size: Optional[int] = None, + shuffle: bool = True, + normalize: bool = False, + reduction_factor: int = 1, + ): + self.audio_root, self.audio_names, inds, tot, self.wav_sizes, self.tgt_audios, self.tgt_sizes, self.tgt_spkembs = load_audio( + manifest_path, max_keep_sample_size, min_keep_sample_size + ) + self.sample_rate = sample_rate + self.shuffle = shuffle + + self.normalize = normalize + self.reduction_factor = reduction_factor + logger.info( + f"reduction_factor={reduction_factor}, normalize={normalize}" + ) + + def get_audio(self, index): + import soundfile as sf + + wav_fbank = [] + for name in [self.audio_names[index], self.tgt_audios[index]]: + wav_path = os.path.join(self.audio_root, name) + wav, cur_sample_rate = sf.read(wav_path) + wav = torch.from_numpy(wav).float() + fbank = logmelfilterbank( + wav.view(-1).cpu().numpy(), 16000 + ) + fbank = torch.from_numpy(fbank).float() + wav = self.postprocess(wav, cur_sample_rate) + wav_fbank.append(wav) + wav_fbank.append(fbank) + src_wav, src_fbank, tgt_wav, tgt_fbank = wav_fbank + return src_wav, src_fbank, tgt_wav, tgt_fbank + + def __getitem__(self, index): + src_wav, src_fbank, tgt_wav, tgt_fbank = self.get_audio(index) + spkembs = np.load(os.path.join(self.audio_root, self.tgt_spkembs[index])) + spkembs = torch.from_numpy(spkembs).float() + name = self.audio_names[index].replace("/", ".").replace(".wav", "") + "-" + self.tgt_audios[index].replace("/", ".").replace(".wav", "") + ".wav" + return {"id": index, "source": src_wav, "target": tgt_fbank, "spkembs": spkembs, "audio_name": name, "tgt_name": self.tgt_audios[index]} + + def __len__(self): + return len(self.wav_sizes) + + def collater(self, samples): + samples = [s for s in samples if s["source"] is not None] + if len(samples) == 0: + return {} + + audios = [s["source"] for s in samples] + audio_sizes = [len(s) for s in audios] + + audio_size = max(audio_sizes) + collated_audios, padding_mask = self.collater_audio( + audios, audio_size + ) + + fbanks = [s["target"] for s in samples] + fbank_sizes = [len(s) for s in fbanks] + + collated_fbanks = _collate_frames(fbanks) + collated_fbanks_size = torch.tensor(fbank_sizes, dtype=torch.long) + + # thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim) + if self.reduction_factor > 1: + collated_fbanks_in = collated_fbanks[:, self.reduction_factor - 1 :: self.reduction_factor] + collated_fbanks_size_in = collated_fbanks_size.new([torch.div(olen, self.reduction_factor, rounding_mode='floor') for olen in collated_fbanks_size]) + else: + collated_fbanks_in, collated_fbanks_size_in = collated_fbanks, collated_fbanks_size + + prev_output_tokens = torch.cat( + [collated_fbanks_in.new_zeros((collated_fbanks_in.shape[0], 1, collated_fbanks_in.shape[2])), collated_fbanks_in[:, :-1]], dim=1 + ) + + # make labels for stop prediction + labels = 
collated_fbanks.new_zeros(collated_fbanks.size(0), collated_fbanks.size(1)) + for i, l in enumerate(fbank_sizes): + labels[i, l - 1 :] = 1.0 + + spkembs = _collate_frames([s["spkembs"] for s in samples], is_audio_input=True) + + net_input = { + "source": collated_audios, + "padding_mask": padding_mask, + "prev_output_tokens": prev_output_tokens, + "tgt_lengths": collated_fbanks_size_in, + "spkembs": spkembs, + "task_name": "s2s", + } + batch = { + "id": torch.LongTensor([s["id"] for s in samples]), + "name": [s["audio_name"] for s in samples], + "tgt_name": [s["tgt_name"] for s in samples], + "net_input": net_input, + "labels": labels, + "dec_target": collated_fbanks, + "dec_target_lengths": collated_fbanks_size, + "src_lengths": torch.LongTensor(audio_sizes), + "task_name": "s2s", + "ntokens": sum(audio_sizes), + "target": collated_fbanks, + } + + return batch + + def collater_audio(self, audios, audio_size): + collated_audios = audios[0].new_zeros(len(audios), audio_size) + padding_mask = ( + torch.BoolTensor(collated_audios.shape).fill_(False) + ) + for i, audio in enumerate(audios): + diff = len(audio) - audio_size + if diff == 0: + collated_audios[i] = audio + elif diff < 0: + collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)]) + padding_mask[i, diff:] = True + else: + raise Exception("Diff should not be larger than 0") + return collated_audios, padding_mask + + + def num_tokens(self, index): + return self.wav_sizes[index] + + def size(self, index): + return self.wav_sizes[index], self.tgt_sizes[index] + + @property + def sizes(self): + return np.array(self.wav_sizes) + + @property + def can_reuse_epoch_itr_across_epochs(self): + """No cache dataset if dataset is large-scale. Cache dataset for small dataset.""" + return True + + def ordered_indices(self): + if self.shuffle: + order = [np.random.permutation(len(self))] + else: + order = [np.arange(len(self))] + + order.append(self.wav_sizes) + return np.lexsort(order)[::-1] + + def postprocess(self, wav, cur_sample_rate): + if wav.dim() == 2: + wav = wav.mean(-1) + assert wav.dim() == 1, wav.dim() + + if cur_sample_rate != self.sample_rate: + raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}") + + if self.normalize: + with torch.no_grad(): + wav = F.layer_norm(wav, wav.shape) + return wav diff --git a/artst/data/speech_to_text_dataset.py b/artst/data/speech_to_text_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..10249f25d9b51aadbe6d1464f0ec9ed8aaa03db5 --- /dev/null +++ b/artst/data/speech_to_text_dataset.py @@ -0,0 +1,298 @@ +# -------------------------------------------------------- +# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621) +# Github source: https://github.com/mbzuai-nlp/ArTST +# Based on speecht5, fairseq and espnet code bases +# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet +# -------------------------------------------------------- + +import itertools +import logging +import os +import mmap +from typing import Any, List, Optional + +import numpy as np + +import torch +torch.set_printoptions(profile="full") +import torch.nn.functional as F +from fairseq.data import data_utils, Dictionary +from fairseq.data.fairseq_dataset import FairseqDataset + +logger = logging.getLogger(__name__) + + +def load_audio(manifest_path, max_keep, min_keep): + n_long, n_short = 0, 0 + names, inds, sizes = [], [], [] + with open(manifest_path) as f: + root = 
f.readline().strip() + for ind, line in enumerate(f): + items = line.strip().split("\t") + assert len(items) >= 2, line + sz = int(items[1]) + if min_keep is not None and sz < min_keep: + n_short += 1 + elif max_keep is not None and sz > max_keep: + n_long += 1 + else: + names.append(items[0]) + inds.append(ind) + sizes.append(sz) + tot = ind + 1 + logger.info( + ( + f"max_keep={max_keep}, min_keep={min_keep}, " + f"loaded {len(names)}, skipped {n_short} short and {n_long} long, " + f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}" + ) + ) + return root, names, inds, tot, sizes + + +def load_label(label_path, inds, tot): + with open(label_path) as f: + labels = [line.rstrip() for line in f] + assert ( + len(labels) == tot + ), f"number of labels does not match ({len(labels)} != {tot})" + labels = [labels[i] for i in inds] + return labels + + +def load_label_offset(label_path, inds, tot): + with open(label_path) as f: + # Hawau: + # changed line length reading as it's incorrect + code_lengths = [len(line.encode("utf-8")) for line in f] #original + # code_lengths = [len(line) for line in f] #fix + assert ( + len(code_lengths) == tot + ), f"number of labels does not match ({len(code_lengths)} != {tot})" + offsets = list(itertools.accumulate([0] + code_lengths)) + offsets = [(offsets[i], offsets[i + 1]) for i in inds] + return offsets + + +class SpeechToTextDataset(FairseqDataset): + def __init__( + self, + manifest_path: str, + sample_rate: float, + label_paths: List[str], + label_processors: Optional[List[Any]] = None, + max_keep_sample_size: Optional[int] = None, + min_keep_sample_size: Optional[int] = None, + shuffle: bool = True, + normalize: bool = False, + store_labels: bool = True, + tgt_dict: Optional[Dictionary] = None, + tokenizer = None, + ): + self.audio_root, self.audio_names, inds, tot, self.wav_sizes = load_audio( + manifest_path, max_keep_sample_size, min_keep_sample_size + ) + + self.sample_rate = sample_rate + self.shuffle = shuffle + self.tgt_dict = tgt_dict + self.tokenizer = tokenizer + + self.num_labels = len(label_paths) + self.label_processors = label_processors + self.store_labels = store_labels + + if store_labels: + self.label_list = [load_label(p, inds, tot) for p in label_paths] + logger.info(f"label_list: {self.label_list}") + else: + self.label_paths = label_paths + self.label_offsets_list = [ + load_label_offset(p, inds, tot) for p in label_paths + ] + # logger.info(f"label_offsets_list: {self.label_offsets_list}") + assert label_processors is None or len(label_processors) == self.num_labels + + self.normalize = normalize + logger.info( + f"normalize={normalize}" + ) + + def get_audio(self, index): + import soundfile as sf + # Hawau: + # logger.info(f"loaded_audio: {self.audio_names[index]}") + wav_path = os.path.join(self.audio_root, self.audio_names[index]) + wav, cur_sample_rate = sf.read(wav_path) + wav = torch.from_numpy(wav).float() + wav = self.postprocess(wav, cur_sample_rate) + return wav + + def get_label(self, index, label_idx): + if self.store_labels: + label = self.label_list[label_idx][index] + else: + # list slicing method + # with open(self.label_paths[label_idx]) as f: + # offset_s, offset_e = self.label_offsets_list[label_idx][index] + # # Hawau: + # # f.seek(offset_s) + # # label = f.read(offset_e - offset_s) + # label = f.read()[offset_s : offset_e] + # Hawau: + # mmap method + with open(self.label_paths[label_idx], encoding='utf-8') as f: + with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm: + offset_s, offset_e = 
self.label_offsets_list[label_idx][index] + label = mm[offset_s:offset_e].decode("utf-8") + + + # Hawau: + # logger.info(f"loaded_label: {label}") + if self.tokenizer is not None: + label = self.tokenizer.encode(label) + + if self.label_processors is not None: + label = self.label_processors[label_idx](label) + # logger.info(f"processed_label: {label}") + return label + + def get_labels(self, index): + return [self.get_label(index, i) for i in range(self.num_labels)] + + def __getitem__(self, index): + wav = self.get_audio(index) + labels = self.get_labels(index) + return {"id": index, "source": wav, "label_list": labels} + + def __len__(self): + return len(self.wav_sizes) + + def collater(self, samples): + samples = [s for s in samples if s["source"] is not None] + if len(samples) == 0: + return {} + + audios = [s["source"] for s in samples] + audio_sizes = [len(s) for s in audios] + + audio_size = max(audio_sizes) + collated_audios, padding_mask = self.collater_audio( + audios, audio_size + ) + + targets_by_label = [ + [s["label_list"][i] for s in samples] for i in range(self.num_labels) + ] + targets_list, lengths_list, ntokens_list = self.collater_label(targets_by_label) + + # Hawau: + # logger.info(f'targets_list: {targets_list}') + + + decoder_label = [ + torch.cat((targets_list[0][i, :lengths_list[0][i]], torch.tensor([self.tgt_dict.eos()])), 0).long() + for i in range(targets_list[0].size(0)) + ] + + decoder_target = data_utils.collate_tokens( + decoder_label, + self.tgt_dict.pad(), + self.tgt_dict.eos(), + left_pad=False, + move_eos_to_beginning=False, + ) + decoder_target_lengths = torch.tensor( + [x.size(0) for x in decoder_label], dtype=torch.long + ) + prev_output_tokens = data_utils.collate_tokens( + decoder_label, + self.tgt_dict.pad(), + self.tgt_dict.eos(), + left_pad=False, + move_eos_to_beginning=True, + ) + + net_input = { + "source": collated_audios, + "padding_mask": padding_mask, + "prev_output_tokens": prev_output_tokens, + "task_name": "s2t", + } + batch = { + "id": torch.LongTensor([s["id"] for s in samples]), + "net_input": net_input, + "target": decoder_target, + "target_lengths": decoder_target_lengths, + "task_name": "s2t", + "ntokens": ntokens_list[0] + } + + return batch + + def collater_audio(self, audios, audio_size): + collated_audios = audios[0].new_zeros(len(audios), audio_size) + padding_mask = ( + torch.BoolTensor(collated_audios.shape).fill_(False) + ) + for i, audio in enumerate(audios): + diff = len(audio) - audio_size + if diff == 0: + collated_audios[i] = audio + elif diff < 0: + collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)]) + padding_mask[i, diff:] = True + else: + raise Exception("Diff should not be larger than 0") + return collated_audios, padding_mask + + def collater_seq_label(self, targets, pad): + lengths = torch.LongTensor([len(t) for t in targets]) + ntokens = lengths.sum().item() + targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False) + return targets, lengths, ntokens + + def collater_label(self, targets_by_label): + targets_list, lengths_list, ntokens_list = [], [], [] + itr = zip(targets_by_label, [self.tgt_dict.pad()]) + + for targets, pad in itr: + # Hawau: + # logger.info(f'targets: {targets}') + targets, lengths, ntokens = self.collater_seq_label(targets, pad) + targets_list.append(targets) + lengths_list.append(lengths) + ntokens_list.append(ntokens) + return targets_list, lengths_list, ntokens_list + + def num_tokens(self, index): + return self.size(index) + + def size(self, index): 
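+ # Note (added): size() reports only the source waveform length; fairseq relies on it (via num_tokens) for --max-tokens batching and --max-positions filtering.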
+ return self.wav_sizes[index] + + @property + def sizes(self): + return np.array(self.wav_sizes) + + def ordered_indices(self): + if self.shuffle: + order = [np.random.permutation(len(self))] + else: + order = [np.arange(len(self))] + + order.append(self.wav_sizes) + return np.lexsort(order)[::-1] + + def postprocess(self, wav, cur_sample_rate): + if wav.dim() == 2: + wav = wav.mean(-1) + assert wav.dim() == 1, wav.dim() + + if cur_sample_rate != self.sample_rate: + raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}") + + if self.normalize: + with torch.no_grad(): + wav = F.layer_norm(wav, wav.shape) + return wav diff --git a/artst/data/text_dataset.py b/artst/data/text_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b7c28e437bd28191f3269d86876f0ad2c82c048b --- /dev/null +++ b/artst/data/text_dataset.py @@ -0,0 +1,474 @@ +# -------------------------------------------------------- +# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621) +# Github source: https://github.com/mbzuai-nlp/ArTST +# Based on speecht5, fairseq and espnet code bases +# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet +# -------------------------------------------------------- + +import math + +import numpy as np +import torch + +from fairseq.data import FairseqDataset, data_utils + + +def collate( + samples, + pad_idx, + eos_idx, + vocab, + left_pad_source=False, + left_pad_target=False, + input_feeding=True, + pad_to_length=None, +): + assert input_feeding + if len(samples) == 0: + return {} + + def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None): + return data_utils.collate_tokens( + [s[key] for s in samples], + pad_idx, + eos_idx=None, # use eos_idx of each sample instead of vocab.eos() + left_pad=left_pad, + move_eos_to_beginning=move_eos_to_beginning, + pad_to_length=pad_to_length, + ) + + id = torch.LongTensor([s["id"] for s in samples]) + src_tokens = merge( + "source", + left_pad=left_pad_source, + pad_to_length=pad_to_length["source"] if pad_to_length is not None else None, + ) + # sort by descending source length + src_lengths = torch.LongTensor([s["source"].numel() for s in samples]) + src_lengths, sort_order = src_lengths.sort(descending=True) + id = id.index_select(0, sort_order) + src_tokens = src_tokens.index_select(0, sort_order) + + prev_output_tokens = None + target = None + if samples[0].get("target", None) is not None: + target = merge( + "target", + left_pad=left_pad_target, + pad_to_length=pad_to_length["target"] + if pad_to_length is not None + else None, + ) + target = target.index_select(0, sort_order) + ntokens = sum(len(s["target"]) for s in samples) + + if input_feeding: + # we create a shifted version of targets for feeding the + # previous output token(s) into the next decoder step + prev_output_tokens = merge( + "target", + left_pad=left_pad_target, + move_eos_to_beginning=True, + pad_to_length=pad_to_length["target"] + if pad_to_length is not None + else None, + ) + prev_output_tokens = prev_output_tokens.index_select(0, sort_order) + else: + ntokens = sum(len(s["source"]) for s in samples) + + batch = { + "id": id, + "ntokens": ntokens, + "net_input": { + "src_tokens": src_tokens, + "src_lengths": src_lengths, + }, + "target": target, + "nsentences": samples[0]["source"].size(0), + "sort_order": sort_order, + "task_name": 'text_pretrain', + } + if prev_output_tokens is not None: + batch["net_input"]["prev_output_tokens"] 
= prev_output_tokens + + return batch + + +class TextPretrainDataset(FairseqDataset): + """ + A wrapper around TokenBlockDataset for BART dataset. + + Args: + dataset (TokenBlockDataset): dataset to wrap + sizes (List[int]): sentence lengths + vocab (~fairseq.data.Dictionary): vocabulary + mask_idx (int): dictionary index used for masked token + mask_whole_words: only mask whole words. This should be a byte mask + over vocab indices, indicating whether it is the beginning of a + word. We will extend any mask to encompass the whole word. + shuffle (bool, optional): shuffle the elements before batching. + Default: ``True`` + seed: Seed for random number generator for reproducibility. + args: argparse arguments. + """ + + def __init__( + self, + dataset, + sizes, + vocab, + mask_idx, + mask_whole_words, + shuffle, + seed, + args, + eos=None, + item_transform_func=None, + iid_noise_target=False, + uni_mask_idxs=None, + ): + self.dataset = dataset + + self.sizes = sizes + + self.vocab = vocab + self.shuffle = shuffle + self.seed = seed + if iid_noise_target: + assert isinstance(uni_mask_idxs, torch.Tensor), "if use iid_noise_target, the uni_mask_idxs must be a tensor which contain the mask indexs" + self.iid_noise_target = iid_noise_target + self.uni_mask_idxs = uni_mask_idxs + self.mask_idx = mask_idx + self.mask_whole_word = mask_whole_words + self.mask_ratio = args.mask + self.random_ratio = args.mask_random + self.insert_ratio = args.insert + self.rotate_ratio = args.rotate + self.permute_sentence_ratio = args.permute_sentences + self.eos = eos if eos is not None else vocab.eos() + self.item_transform_func = item_transform_func + + if args.bpe != "gpt2": + self.full_stop_index = self.vocab.eos() + else: + assert args.bpe == "gpt2" + self.full_stop_index = self.vocab.index("13") + + self.replace_length = args.replace_length + if self.replace_length not in [-1, 0, 1]: + raise ValueError(f"invalid arg: replace_length={self.replace_length}") + if args.mask_length not in ["subword", "word", "span-poisson"]: + raise ValueError(f"invalid arg: mask-length={args.mask_length}") + if args.mask_length == "subword" and args.replace_length not in [0, 1]: + raise ValueError(f"if using subwords, use replace-length=1 or 0") + + self.mask_span_distribution = None + if args.mask_length == "span-poisson": + _lambda = args.poisson_lambda + + lambda_to_the_k = 1 + e_to_the_minus_lambda = math.exp(-_lambda) + k_factorial = 1 + ps = [] + for k in range(0, 128): + ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial) + lambda_to_the_k *= _lambda + k_factorial *= k + 1 + if ps[-1] < 0.0000001: + break + ps = torch.FloatTensor(ps) + self.mask_span_distribution = torch.distributions.Categorical(ps) + + self.epoch = 0 + + @property + def can_reuse_epoch_itr_across_epochs(self): + return True # only the noise changes, not item sizes + + def set_epoch(self, epoch, **unused): + self.epoch = epoch + + def __getitem__(self, index): + with data_utils.numpy_seed(self.seed, self.epoch, index): + tokens = self.dataset[index] + assert tokens[-1] == self.eos + source, target = tokens, tokens.clone() + + if self.permute_sentence_ratio > 0.0: + source = self.permute_sentences(source, self.permute_sentence_ratio) + + if self.mask_ratio > 0: + source, new_target = self.add_whole_word_mask(source, self.mask_ratio) + if new_target is not None: + target = new_target + + if self.insert_ratio > 0: + source = self.add_insertion_noise(source, self.insert_ratio) + + if self.rotate_ratio > 0.0 and np.random.random() < 
self.rotate_ratio: + source = self.add_rolling_noise(source) + # there can additional changes to make: + if self.item_transform_func is not None: + source, target = self.item_transform_func(source, target) + + assert (source >= 0).all() + assert (source[1:-1] >= 1).all() + assert (source <= len(self.vocab)).all() + assert source[0] == self.vocab.bos() + assert source[-1] == self.eos + return { + "id": index, + "source": source, + "target": target, + } + + def __len__(self): + return len(self.dataset) + + def permute_sentences(self, source, p=1.0): + full_stops = source == self.full_stop_index + # Pretend it ends with a full stop so last span is a sentence + full_stops[-2] = 1 + + # Tokens that are full stops, where the previous token is not + sentence_ends = (full_stops[1:] * ~full_stops[:-1]).nonzero(as_tuple=False) + 2 + result = source.clone() + + num_sentences = sentence_ends.size(0) + num_to_permute = math.ceil((num_sentences * 2 * p) / 2.0) + substitutions = torch.randperm(num_sentences)[:num_to_permute] + ordering = torch.arange(0, num_sentences) + ordering[substitutions] = substitutions[torch.randperm(num_to_permute)] + + # Ignore at start + index = 1 + for i in ordering: + sentence = source[(sentence_ends[i - 1] if i > 0 else 1) : sentence_ends[i]] + result[index : index + sentence.size(0)] = sentence + index += sentence.size(0) + return result + + def word_starts(self, source): + if self.mask_whole_word is not None: + is_word_start = self.mask_whole_word.gather(0, source) + else: + is_word_start = torch.ones(source.size()) + is_word_start[0] = 0 + is_word_start[-1] = 0 + return is_word_start + + def add_whole_word_mask(self, source, p): + source_ori = source.clone() + is_word_start = self.word_starts(source) + num_to_mask = int(math.ceil(is_word_start.float().sum() * p)) + num_inserts = 0 + if num_to_mask == 0: + return source + + if self.mask_span_distribution is not None: + lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,)) + + # Make sure we have enough to mask + cum_length = torch.cumsum(lengths, 0) + while cum_length[-1] < num_to_mask: + lengths = torch.cat( + [ + lengths, + self.mask_span_distribution.sample(sample_shape=(num_to_mask,)), + ], + dim=0, + ) + cum_length = torch.cumsum(lengths, 0) + + # Trim to masking budget + i = 0 + while cum_length[i] < num_to_mask: + i += 1 + lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1]) + num_to_mask = i + 1 + lengths = lengths[:num_to_mask] + + # Handle 0-length mask (inserts) separately + lengths = lengths[lengths > 0] + num_inserts = num_to_mask - lengths.size(0) + num_to_mask -= num_inserts + if num_to_mask == 0: + return self.add_insertion_noise(source, num_inserts / source.size(0)) + + assert (lengths > 0).all() + else: + lengths = torch.ones((num_to_mask,)).long() + assert is_word_start[-1] == 0 + word_starts = is_word_start.nonzero(as_tuple=False) + indices = word_starts[ + torch.randperm(word_starts.size(0))[:num_to_mask] + ].squeeze(1) + mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio + + source_length = source.size(0) + assert source_length - 1 not in indices + to_keep = torch.ones(source_length, dtype=torch.bool) + is_word_start[ + -1 + ] = 255 # acts as a long length, so spans don't go over the end of doc + if self.replace_length == 0: + to_keep[indices] = 0 + else: + # keep index, but replace it with [MASK] + source[indices] = self.mask_idx + source[indices[mask_random]] = torch.randint( + 1, len(self.vocab), size=(mask_random.sum(),) + ) + + if 
self.mask_span_distribution is not None: + assert len(lengths.size()) == 1 + assert lengths.size() == indices.size() + lengths -= 1 + while indices.size(0) > 0: + assert lengths.size() == indices.size() + lengths -= is_word_start[indices + 1].long() + uncompleted = lengths >= 0 + indices = indices[uncompleted] + 1 + mask_random = mask_random[uncompleted] + lengths = lengths[uncompleted] + if self.replace_length != -1: + # delete token + to_keep[indices] = 0 + else: + # keep index, but replace it with [MASK] + source[indices] = self.mask_idx + source[indices[mask_random]] = torch.randint( + 1, len(self.vocab), size=(mask_random.sum(),) + ) + else: + # A bit faster when all lengths are 1 + while indices.size(0) > 0: + uncompleted = is_word_start[indices + 1] == 0 + indices = indices[uncompleted] + 1 + mask_random = mask_random[uncompleted] + if self.replace_length != -1: + # delete token + to_keep[indices] = 0 + else: + # keep index, but replace it with [MASK] + source[indices] = self.mask_idx + source[indices[mask_random]] = torch.randint( + 1, len(self.vocab), size=(mask_random.sum(),) + ) + + assert source_length - 1 not in indices + + if not self.iid_noise_target: + source = source[to_keep] + target = None + else: + ## Prepare source + source_mask_idx = (source == self.mask_idx).nonzero().view(-1) + source[source_mask_idx] = self.uni_mask_idxs[:source_mask_idx.size(0)] + source = source[to_keep] + + ## Prepare target + to_keep[source_mask_idx] = 0 + + # source_mask_idx: from [a, b, c, ...] to [a, b + 1, c + 2, ...] + source_mask_idx = source_mask_idx + torch.arange(source_mask_idx.size(0)) + # target: source_length + mask_length + target = source_ori.new_zeros(source_mask_idx.size(0) + source_ori.size(0)) + # target: [0, 0, 0, X, 0, 0, Y, ....] + target[source_mask_idx] = self.uni_mask_idxs[:source_mask_idx.size(0)] + + target_to_keep = to_keep.new_zeros(source_mask_idx.size(0) + source_ori.size(0)) + + # Copy original value to target and target_to_keep + target_to_keep[target == 0] = to_keep + target_to_keep[-1] = 0 + target[target == 0] = source_ori + + target = target[~target_to_keep] + + if num_inserts > 0: + source = self.add_insertion_noise(source, num_inserts / source.size(0)) + + return source, target + + def add_permuted_noise(self, tokens, p): + num_words = len(tokens) + num_to_permute = math.ceil(((num_words * 2) * p) / 2.0) + substitutions = torch.randperm(num_words - 2)[:num_to_permute] + 1 + tokens[substitutions] = tokens[substitutions[torch.randperm(num_to_permute)]] + return tokens + + def add_rolling_noise(self, tokens): + offset = np.random.randint(1, max(1, tokens.size(-1) - 1) + 1) + tokens = torch.cat( + (tokens[0:1], tokens[offset:-1], tokens[1:offset], tokens[-1:]), + dim=0, + ) + return tokens + + def add_insertion_noise(self, tokens, p): + if p == 0.0: + return tokens + + num_tokens = len(tokens) + n = int(math.ceil(num_tokens * p)) + + noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1 + noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool) + noise_mask[noise_indices] = 1 + result = torch.LongTensor(n + len(tokens)).fill_(-1) + + num_random = int(math.ceil(n * self.random_ratio)) + result[noise_indices[num_random:]] = self.mask_idx + result[noise_indices[:num_random]] = torch.randint( + low=1, high=len(self.vocab), size=(num_random,) + ) + + result[~noise_mask] = tokens + + assert (result >= 0).all() + return result + + def collater(self, samples, pad_to_length=None): + """Merge a list of samples to form a mini-batch. 
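+ Collation is delegated to the module-level collate() above: sources are padded with vocab.pad(), the batch is sorted by descending source length, and prev_output_tokens is built by moving each target's EOS to the front (input feeding).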
+ Args: + samples (List[dict]): samples to collate + Returns: + dict: a mini-batch of data + """ + return collate( + samples, self.vocab.pad(), self.eos, self.vocab, pad_to_length=pad_to_length + ) + + def num_tokens(self, index): + """Return the number of tokens in a sample. This value is used to + enforce ``--max-tokens`` during batching.""" + return self.sizes[index] + + def size(self, index): + """Return an example's size as a float or tuple. This value is used when + filtering a dataset with ``--max-positions``.""" + return self.sizes[index] + + def ordered_indices(self): + """Return an ordered list of indices. Batches will be constructed based + on this order.""" + if self.shuffle: + indices = np.random.permutation(len(self)) + else: + indices = np.arange(len(self)) + return indices[np.argsort(self.sizes[indices], kind="mergesort")] + + def prefetch(self, indices): + self.src.prefetch(indices) + self.tgt.prefetch(indices) + + @property + def supports_prefetch(self): + return ( + hasattr(self.src, "supports_prefetch") + and self.src.supports_prefetch + and hasattr(self.tgt, "supports_prefetch") + and self.tgt.supports_prefetch + ) diff --git a/artst/data/text_to_speech_dataset.py b/artst/data/text_to_speech_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..0731fbfc5d8e2309376bff4efbbe8f2ba00803a2 --- /dev/null +++ b/artst/data/text_to_speech_dataset.py @@ -0,0 +1,344 @@ +# -------------------------------------------------------- +# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621) +# Github source: https://github.com/mbzuai-nlp/ArTST +# Based on speecht5, fairseq and espnet code bases +# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet +# -------------------------------------------------------- + +import itertools +import logging +import os +from typing import Any, List, Optional +import mmap + +import numpy as np + +import torch +import torch.nn.functional as F +import librosa +from fairseq.data.audio.speech_to_text_dataset import get_features_or_waveform +from fairseq.data import data_utils, Dictionary +from fairseq.data.fairseq_dataset import FairseqDataset + + +logger = logging.getLogger(__name__) + +def _collate_frames( + frames: List[torch.Tensor], is_audio_input: bool = False +): + """ + Convert a list of 2D frames into a padded 3D tensor + Args: + frames (list): list of 2D frames of size L[i]*f_dim. 
Where L[i] is + length of i-th frame and f_dim is static dimension of features + Returns: + 3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i] + """ + max_len = max(frame.size(0) for frame in frames) + if is_audio_input: + out = frames[0].new_zeros((len(frames), max_len)) + else: + out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1))) + for i, v in enumerate(frames): + out[i, : v.size(0)] = v + return out + +def load_audio(manifest_path, max_keep, min_keep): + n_long, n_short = 0, 0 + names, inds, sizes, spk_embeds = [], [], [], [] + with open(manifest_path) as f: + root = f.readline().strip() + for ind, line in enumerate(f): + items = line.strip().split("\t") + assert len(items) == 3, line + sz = int(items[1]) + if min_keep is not None and sz < min_keep: + n_short += 1 + elif max_keep is not None and sz > max_keep: + n_long += 1 + else: + names.append(items[0]) + spk_embeds.append(items[2]) + inds.append(ind) + sizes.append(sz) + tot = ind + 1 + logger.info( + ( + f"max_keep={max_keep}, min_keep={min_keep}, " + f"loaded {len(names)}, skipped {n_short} short and {n_long} long, " + f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}" + ) + ) + return root, names, inds, tot, sizes, spk_embeds + + +def load_label(label_path, inds, tot): + with open(label_path) as f: + labels = [line.rstrip() for line in f] + assert ( + len(labels) == tot + ), f"number of labels does not match ({len(labels)} != {tot})" + labels = [labels[i] for i in inds] + return labels + + +def load_label_offset(label_path, inds, tot): + with open(label_path, encoding='utf-8') as f: + code_lengths = [len(line.encode("utf-8")) for line in f] #changed as in speech_to_text_dataset.py + assert ( + len(code_lengths) == tot + ), f"number of labels does not match ({len(code_lengths)} != {tot})" + offsets = list(itertools.accumulate([0] + code_lengths)) + offsets = [(offsets[i], offsets[i + 1]) for i in inds] + return offsets + + +def logmelfilterbank( + audio, + sampling_rate, + fft_size=1024, + hop_size=256, + win_length=None, + window="hann", + num_mels=80, + fmin=80, + fmax=7600, + eps=1e-10, +): + """Compute log-Mel filterbank feature. + (https://github.com/kan-bayashi/ParallelWaveGAN/blob/master/parallel_wavegan/bin/preprocess.py) + + Args: + audio (ndarray): Audio signal (T,). + sampling_rate (int): Sampling rate. + fft_size (int): FFT size. + hop_size (int): Hop size. + win_length (int): Window length. If set to None, it will be the same as fft_size. + window (str): Window function type. + num_mels (int): Number of mel basis. + fmin (int): Minimum frequency in mel basis calculation. + fmax (int): Maximum frequency in mel basis calculation. + eps (float): Epsilon value to avoid inf in log calculation. + + Returns: + ndarray: Log Mel filterbank feature (#frames, num_mels). 
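+ Note: with the 16 kHz audio and defaults used by this dataset (fft_size=1024, hop_size=256, num_mels=80), each frame covers 64 ms with a 16 ms hop, and magnitudes are floored at eps before the log10 compression.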
+ + """ + # get amplitude spectrogram + x_stft = librosa.stft(audio, n_fft=fft_size, hop_length=hop_size, + win_length=win_length, window=window, pad_mode="reflect") + spc = np.abs(x_stft).T # (#frames, #bins) + + # get mel basis + fmin = 0 if fmin is None else fmin + fmax = sampling_rate / 2 if fmax is None else fmax + mel_basis = librosa.filters.mel(sr=sampling_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax) + + return np.log10(np.maximum(eps, np.dot(spc, mel_basis.T))) + + + +class TextToSpeechDataset(FairseqDataset): + def __init__( + self, + manifest_path: str, + sample_rate: float, + label_paths: List[str], + label_processors: Optional[List[Any]] = None, + max_keep_sample_size: Optional[int] = None, + min_keep_sample_size: Optional[int] = None, + shuffle: bool = True, + normalize: bool = False, + store_labels: bool = True, + src_dict: Optional[Dictionary] = None, + tokenizer = None, + reduction_factor: int = 1, + inference: bool = False, + ): + + self.audio_root, self.audio_names, inds, tot, self.wav_sizes, self.spk_embeds = load_audio( + manifest_path, max_keep_sample_size, min_keep_sample_size + ) + self.inference = inference + + self.sample_rate = sample_rate + self.shuffle = shuffle + self.src_dict = src_dict + self.tokenizer = tokenizer + + self.num_labels = len(label_paths) + self.label_processors = label_processors + self.store_labels = store_labels + if store_labels: + self.label_list = [load_label(p, inds, tot) for p in label_paths] + else: + self.label_paths = label_paths + self.label_offsets_list = [ + load_label_offset(p, inds, tot) for p in label_paths + ] + assert label_processors is None or len(label_processors) == self.num_labels + + self.normalize = normalize + self.reduction_factor = reduction_factor + logger.info( + f"reduction_factor={reduction_factor}, normalize={normalize}" + ) + + def get_audio(self, index): + import soundfile as sf + + wav_path = os.path.join(self.audio_root, self.audio_names[index]) + wav, cur_sample_rate = sf.read(wav_path) + wav = torch.from_numpy(wav).float() + fbank = logmelfilterbank( + wav.view(-1).cpu().numpy(), 16000 + ) + fbank = torch.from_numpy(fbank).float() + wav = self.postprocess(wav, cur_sample_rate) + return wav, fbank + + def get_label(self, index, label_idx): + if self.store_labels: + label = self.label_list[label_idx][index] + else: + # with open(self.label_paths[label_idx]) as f: + # offset_s, offset_e = self.label_offsets_list[label_idx][index] + # f.seek(offset_s) + # label = f.read(offset_e - offset_s) + + # Hawau: + # mmap method + with open(self.label_paths[label_idx], encoding='utf-8') as f: + with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm: + offset_s, offset_e = self.label_offsets_list[label_idx][index] + label = mm[offset_s:offset_e].decode("utf-8") + + + if self.tokenizer is not None: + label = self.tokenizer.encode(label) + + if self.label_processors is not None: + label = self.label_processors[label_idx](label) + return label + + def get_labels(self, index): + return [self.get_label(index, i) for i in range(self.num_labels)] + + def __getitem__(self, index): + wav, fbank = self.get_audio(index) + labels = self.get_labels(index) + spkembs = get_features_or_waveform( + os.path.join(self.audio_root, self.spk_embeds[index]) + ) + spkembs = torch.from_numpy(spkembs).float() + + return {"id": index, "source": labels, "target": fbank, "spkembs": spkembs, "audio_name": self.audio_names[index]} + + + def __len__(self): + return len(self.wav_sizes) + + def collater(self, samples): + samples = 
[s for s in samples if s["source"] is not None] + if len(samples) == 0: + return {} + + fbanks = [s["target"] for s in samples] + fbank_sizes = [len(s) for s in fbanks] + + collated_fbanks = _collate_frames(fbanks) + collated_fbanks_size = torch.tensor(fbank_sizes, dtype=torch.long) + + # thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim) + if self.reduction_factor > 1: + collated_fbanks_in = collated_fbanks[:, self.reduction_factor - 1 :: self.reduction_factor] + collated_fbanks_size_in = collated_fbanks_size.new([torch.div(olen, self.reduction_factor, rounding_mode='floor') for olen in collated_fbanks_size]) + else: + collated_fbanks_in, collated_fbanks_size_in = collated_fbanks, collated_fbanks_size + + prev_output_tokens = torch.cat( + [collated_fbanks_in.new_zeros((collated_fbanks_in.shape[0], 1, collated_fbanks_in.shape[2])), collated_fbanks_in[:, :-1]], dim=1 + ) + + # make labels for stop prediction + labels = collated_fbanks.new_zeros(collated_fbanks.size(0), collated_fbanks.size(1)) + for i, l in enumerate(fbank_sizes): + labels[i, l - 1 :] = 1.0 + + spkembs = _collate_frames([s["spkembs"] for s in samples], is_audio_input=True) + + sources_by_label = [ + [s["source"][i] for s in samples] for i in range(self.num_labels) + ] + sources_list, lengths_list, ntokens_list = self.collater_label(sources_by_label) + + net_input = { + "src_tokens": sources_list[0], + "src_lengths": lengths_list[0], + "prev_output_tokens": prev_output_tokens, + "tgt_lengths": collated_fbanks_size_in, + "spkembs": spkembs, + "task_name": "t2s", + } + batch = { + "id": torch.LongTensor([s["id"] for s in samples]), + "name": [s["audio_name"] for s in samples], + "net_input": net_input, + "labels": labels, + "dec_target": collated_fbanks, + "dec_target_lengths": collated_fbanks_size, + "src_lengths": lengths_list[0], + "task_name": "t2s", + "ntokens": ntokens_list[0], + "target": collated_fbanks, + } + + return batch + + def collater_seq_label(self, targets, pad): + lengths = torch.LongTensor([len(t) for t in targets]) + ntokens = lengths.sum().item() + targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False) + return targets, lengths, ntokens + + def collater_label(self, targets_by_label): + targets_list, lengths_list, ntokens_list = [], [], [] + itr = zip(targets_by_label, [self.src_dict.pad()]) + for targets, pad in itr: + targets, lengths, ntokens = self.collater_seq_label(targets, pad) + targets_list.append(targets) + lengths_list.append(lengths) + ntokens_list.append(ntokens) + return targets_list, lengths_list, ntokens_list + + def num_tokens(self, index): + return self.size(index) + + def size(self, index): + return self.wav_sizes[index] + + @property + def sizes(self): + return np.array(self.wav_sizes) + + def ordered_indices(self): + if self.shuffle: + order = [np.random.permutation(len(self))] + else: + order = [np.arange(len(self))] + + order.append(self.wav_sizes) + return np.lexsort(order)[::-1] + + def postprocess(self, wav, cur_sample_rate): + if wav.dim() == 2: + wav = wav.mean(-1) + assert wav.dim() == 1, wav.dim() + + if cur_sample_rate != self.sample_rate: + raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}") + + if self.normalize: + with torch.no_grad(): + wav = F.layer_norm(wav, wav.shape) + return wav diff --git a/artst/models/__init__.py b/artst/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2b159d415b3caec472508ac1dfcca9ac9e65fb80 --- /dev/null +++ b/artst/models/__init__.py @@ -0,0 +1,2 @@ 
+from .artst import * # noqa +from .t5_transformer_lm import * # noqa diff --git a/artst/models/__pycache__/__init__.cpython-38.pyc b/artst/models/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6aef5e38df5aee59bd58f44627b7a7301fe6f63c Binary files /dev/null and b/artst/models/__pycache__/__init__.cpython-38.pyc differ diff --git a/artst/models/__pycache__/artst.cpython-38.pyc b/artst/models/__pycache__/artst.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e07d84a798352c7d1f0a630927f40cd1c594eca1 Binary files /dev/null and b/artst/models/__pycache__/artst.cpython-38.pyc differ diff --git a/artst/models/__pycache__/speecht5.cpython-38.pyc b/artst/models/__pycache__/speecht5.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3f40074edfcbdf1f2481fdcc0dba46825639ec6 Binary files /dev/null and b/artst/models/__pycache__/speecht5.cpython-38.pyc differ diff --git a/artst/models/__pycache__/t5_transformer_lm.cpython-38.pyc b/artst/models/__pycache__/t5_transformer_lm.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7080e3548267213bc2d90ca992a2489a466b498 Binary files /dev/null and b/artst/models/__pycache__/t5_transformer_lm.cpython-38.pyc differ diff --git a/artst/models/artst.py b/artst/models/artst.py new file mode 100644 index 0000000000000000000000000000000000000000..a00bd3781ce35d5d9de4386dd786c57ede652cb9 --- /dev/null +++ b/artst/models/artst.py @@ -0,0 +1,1448 @@ +# -------------------------------------------------------- +# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621) +# Github source: https://github.com/mbzuai-nlp/ArTST +# Based on speecht5, fairseq and espnet code bases +# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet +# -------------------------------------------------------- + +import logging +from ast import literal_eval +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn.functional as F +from fairseq import utils +from fairseq.models import ( + FairseqEncoderDecoderModel, + FairseqIncrementalDecoder, + register_model, + register_model_architecture, +) +from .modules.text_encoder_prenet import TextEncoderPrenet +from .modules.text_decoder_prenet import TextDecoderPrenet +from .modules.text_decoder_postnet import TextDecoderPostnet +from .modules.speech_encoder_prenet import SpeechEncoderPrenet +from .modules.speech_encoder_postnet import SpeechEncoderPostnet +from .modules.speech_decoder_prenet import SpeechDecoderPrenet +from .modules.speech_decoder_postnet import SpeechDecoderPostnet +from .modules.speaker_decoder_postnet import SpeakerDecoderPostnet +from .modules.encoder import TransformerEncoder +from .modules.decoder import TransformerDecoder +from fairseq.modules.transformer_sentence_encoder import init_bert_params +from fairseq.models.transformer import Embedding +from fairseq.modules import ( + GumbelVectorQuantizer, +) +from torch import Tensor + + +logger = logging.getLogger(__name__) + +DEFAULT_MAX_TEXT_POSITIONS = 450 +DEFAULT_MAX_SPEECH_POSITIONS = 4000 + + +@register_model("artst_transformer") +class ArTSTTransformerModel(FairseqEncoderDecoderModel): + """Adapted Transformer model (https://arxiv.org/abs/1706.03762) for + speech-to-text tasks. The Transformer encoder/decoder remains the same. 
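+ In ArTST the Transformer encoder/decoder is shared across modalities: text and speech prenets map either modality into the shared embedding space, while the text, speech, and speaker postnets emit token logits, log-Mel frames, or speaker-classification outputs.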
+ A trainable input subsampler is prepended to the Transformer encoder to + project inputs into the encoder dimension as well as downsample input + sequence for computational efficiency.""" + + def __init__( + self, + args, + encoder, decoder, + text_encoder_prenet, speech_encoder_prenet, + text_decoder_prenet, speech_decoder_prenet, + text_decoder_postnet, speech_decoder_postnet, + speaker_decoder_postnet, speech_encoder_postnet, + ): + super().__init__(encoder, decoder) + + self.encoder = encoder + self.decoder = decoder + + self.text_encoder_prenet = text_encoder_prenet + self.speech_encoder_prenet = speech_encoder_prenet + + self.text_decoder_prenet = text_decoder_prenet + self.speech_decoder_prenet = speech_decoder_prenet + + self.text_decoder_postnet = text_decoder_postnet + self.speech_decoder_postnet = speech_decoder_postnet + self.speaker_decoder_postnet = speaker_decoder_postnet + + self.hubert_layer = speech_encoder_postnet + + self.reduction_factor = args.reduction_factor + self.spk_embed_dim = args.spk_embed_dim + + # define projection layer + self.spk_embed_integration_type = args.spk_embed_integration_type + if self.spk_embed_dim is not None and self.spk_embed_integration_type != 'pre': + if self.spk_embed_integration_type == "add": + self.projection = torch.nn.Linear(self.spk_embed_dim, args.decoder_embed_dim) + else: + self.projection = torch.nn.Linear( + args.decoder_embed_dim + self.spk_embed_dim, args.decoder_embed_dim + ) + + # Hawau: here we can add language embedding integration + + self.use_codebook = args.use_codebook + self.codebook_prob = getattr(args, "codebook_prob", 0.5) # args.codebook_prob + if self.use_codebook: + vq_dim = args.latent_dim if args.latent_dim > 0 else args.encoder_embed_dim + self.quantizer = GumbelVectorQuantizer( + dim=args.encoder_embed_dim, + num_vars=args.latent_vars, + temp=args.latent_temp, + groups=args.latent_groups, + combine_groups=False, + vq_dim=vq_dim, + time_first=True, + weight_proj_depth=args.quantizer_depth, + weight_proj_factor=args.quantizer_factor, + ) + + self.num_updates = 0 + + # # Follow BERT's random weight initialization (for BART) + if args.bert_init: + self.apply(init_bert_params) + self.args = args + self.prune_modules(args.modules_filter) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + # Transformer + parser.add_argument( + "--activation-fn", + type=str, + choices=utils.get_available_activation_fns(), + help="activation function to use", + ) + parser.add_argument( + "--dropout", type=float, metavar="D", help="dropout probability" + ) + parser.add_argument( + "--attention-dropout", + type=float, + metavar="D", + help="dropout probability for attention weights", + ) + parser.add_argument( + "--activation-dropout", + "--relu-dropout", + type=float, + metavar="D", + help="dropout probability after activation in FFN.", + ) + parser.add_argument( + "--encoder-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension", + ) + parser.add_argument( + "--encoder-ffn-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension for FFN", + ) + parser.add_argument( + "--encoder-layers", type=int, metavar="N", help="num encoder layers" + ) + parser.add_argument( + "--encoder-attention-heads", + type=int, + metavar="N", + help="num encoder attention heads", + ) + parser.add_argument( + "--encoder-normalize-before", + action="store_true", + help="apply layernorm before each encoder block", + ) + parser.add_argument( + "--decoder-normalize-before", + 
action="store_true", + help="apply layernorm before each decoder block", + ) + parser.add_argument( + "--decoder-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension", + ) + parser.add_argument( + "--decoder-ffn-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension for FFN", + ) + parser.add_argument( + "--decoder-layers", type=int, metavar="N", help="num decoder layers" + ) + parser.add_argument( + "--decoder-attention-heads", + type=int, + metavar="N", + help="num decoder attention heads", + ) + parser.add_argument( + "--reduction-factor", + type=int, + help="reduction factor for decoder", + ) + parser.add_argument( + "--spk-embed-dim", + type=int, + help="speaker embedding dimension", + ) + parser.add_argument( + "--layernorm-embedding", + action="store_true", + help="add layernorm to embedding", + ) + parser.add_argument( + "--load-pretrained-encoder-from", + type=str, + metavar="STR", + help="model to take encoder weights from (for initialization)", + ) + parser.add_argument( + '--freeze-encoder-updates', + type=int, + help='number of steps to freeze encoder before finetune' + ) + parser.add_argument( + '--freeze-decoder-updates', + type=int, + help='number of steps to freeze decoder before finetune' + ) + parser.add_argument( + '--no-freeze-encoder-layer', + type=str, + help='which encoder layer not freeze during finetune' + ) + parser.add_argument( + "--share-input-output-embed", + action="store_true", + help="share decoder input and output embeddings", + ) + parser.add_argument( + "--share-ctc-embed", + action="store_true", + help="share ctc embed and decoder embed", + ) + parser.add_argument( + "--encoder-sliding-window-attn", + default=None, + type=int, + help="If not None but a even number, set sliding window attention to encoder's attn_mask, e.g., 4, 10, and 20", + ) + + # Convolutional subsampler + parser.add_argument( + "--encoder-speech-prenet", + default="conv", + type=str, + choices=["conv", "linear"], + help="The type of encoder speech prenet, e.g., conv or linear." + ) + parser.add_argument( + "--conv-kernel-sizes", + default="5,5", + type=str, + help="The layer of convolution of encoder speech prenet." + ) + parser.add_argument( + "--conv-channels", + default=1024, + type=int, + help="The channels of encoder speech prenet." + ) + parser.add_argument( + "--subsample-stride", + default="2,2", + type=str, + help="The subsample stride for conv1dsubsample." + ) + parser.add_argument( + "--spk-embed-integration-type", + type=str, + choices=["pre", "add"], + help="speaker embedding integration type" + ) + parser.add_argument( + "--dprenet-dropout-rate", + default=0.5, + type=float, + help="The dropout rate of decoder speech prenet." + ) + + ## SE + parser.add_argument( + "--se-predict", + default=None, + choices=["masking", "target", "delta"], + help="If set, source speech inputs decoder to predict the masking/target/delta of corresponding inputs." 
+ + "masking is [0, 1], target is predicted output, delta is difference between inputs and outputs", + ) + parser.add_argument( + "--se-decoder-input", + type=str, + default="previous_target", + choices=["previous_target", "source"], + ) + + ## SID + parser.add_argument( + "--modules-filter", + default=None, + type=str, + help="Remove unused modules for, e.g., SID.", + ) + parser.add_argument( + "--sid-pad-prenet", + action="store_true", + help="If set, the size of text dictionary is as small as for token.", + ) + parser.add_argument( + "--encoder-attn-branch", + type=str, + default="identity,full", + help="encoder attention branch sliding window, e.g., 'identity,0,2,4,full'", + ) + parser.add_argument( + "--encoder-block-branch", + type=str, + help="average the output of encoder, e.g., '4,5,6'", + ) + parser.add_argument( + "--sid-encoder-cls", + default=None, + choices=["encoder"], + help="If set, add cls vector to the encoder input, e.g., constant vector.", + ) + parser.add_argument( + "--sid-shuffle-encoder-input", + action="store_true", + help="If set, shuffle encoder input in time.", + ) + parser.add_argument( + "--sid-decoder-speaker", + action="store_true", + help="If set, apply speaker decoder as transformer decoder.", + ) + parser.add_argument( + "--sid-decoder-attn-dim", + default=128, + type=int, + help="Attention dimension in attensive statistics pooling of speaker decoder.", + ) + parser.add_argument( + "--sid-t5-postnet", + action="store_true", + help="If set, apply TextDecoderPostnet as speaker classification.", + ) + parser.add_argument( + "--sid-embed-dim", + default=128, + type=int, + help="Embedding dimension in speaker postnet for speaker identification if embed postnet.", + ) + parser.add_argument( + "--sid-pooling-layer", + default="decoder", + type=str, + choices=["decoder-las", "decoder", "encoder", "encoder-cls", "encoder-speaker"], + help="The output of decoder or encoder uses as SID pooling layer over temporal dimension.", + ) + parser.add_argument( + "--sid-no-pooling-bn", + action="store_true", + help="If set, not attention batchnorm.", + ) + parser.add_argument( + "--sid-no-embed-postnet", + action="store_true", + help="If set, no layer between decoder output and classification layer.", + ) + parser.add_argument( + "--sid-normalize-postnet", + action="store_true", + help="If set, normalize input and weight in postnet/classifier.", + ) + parser.add_argument( + "--sid-softmax-type", + default="softmax", + choices=["softmax", "amsoftmax", "aamsoftmax"], + help="If using amsoftmax or aamsoftmax, the target should be given.", + ) + parser.add_argument( + "--softmax-scale", + default=1.0, + type=float, + help="Scale for AMSoftmax or AAMSoftmax.", + ) + parser.add_argument( + "--softmax-margin", + default=0.0, + type=float, + help="Margin for AMSoftmax or AAMSoftmax.", + ) + parser.add_argument( + "--softmax-easy-margin", + action="store_true", + help="Enable easy margin for AAMSoftmax.", + ) + parser.add_argument( + "--encoder-layerdrop", + type=float, + metavar="D", + help="LayerDrop probability for encoder", + ) + parser.add_argument( + "--decoder-layerdrop", + type=float, + metavar="D", + help="LayerDrop probability for decoder", + ) + + ## Hubert + parser.add_argument( + '--feature-grad-mult', + type=float, + help='multiply feature extractor var grads by this' + ) + parser.add_argument( + '--logit-temp', + type=float, + help='temperature to divide logits by' + ) + parser.add_argument( + '--final-dim', + type=int, + help="project final representations and 
targets to this many " + "dimensions. set to encoder_embed_dim is <= 0" + ) + + # mask + parser.add_argument( + '--hubert-mask-length', + type=int, + help='mask length' + ) + parser.add_argument( + '--mask-prob', + type=float, + help='probability of replacing a token with mask' + ) + parser.add_argument( + "--mask-selection", + choices=["static", "uniform", "normal", "poisson"], + help="how to choose mask length", + ) + parser.add_argument( + '--mask-other', + type=float, + help="secondary mask argument " + "(used for more complex distributions), " + "see help in compute_mask_indices" + ) + parser.add_argument( + '--mask-min-space', + type=int, + help='min space between spans (if no overlap is enabled)' + ) + + # channel masking + parser.add_argument( + '--mask-channel-length', + type=int, + help='length of the mask for features (channels)' + ) + parser.add_argument( + '--mask-channel-prob', + type=float, + help="probability of replacing a feature with 0" + ) + parser.add_argument( + "--mask-channel-selection", + choices=["static", "uniform", "normal", "poisson"], + help="how to choose mask length for channel masking", + ) + parser.add_argument( + '--mask-channel-other', + type=float, + help="secondary mask argument " + "(used for more complex distributions), " + "see help in compute_mask_indices" + ) + parser.add_argument( + '--mask-channel-min-space', + type=int, + help='min space between spans (if no overlap is enabled)' + ) + + # abs positional embeddings + parser.add_argument( + '--conv-pos', + type=int, + help='number of filters for convolutional positional embeddings' + ) + parser.add_argument( + '--conv-pos-groups', + type=int, + help='number of groups for convolutional positional embedding' + ) + + # codebook related + parser.add_argument( + "--use-codebook", + action="store_true", + help="whether to use codebook", + ) + parser.add_argument( + "--codebook-prob", + type=float, + help="probability to use codebook", + ) + parser.add_argument( + "--latent-vars", + type=int, + help="number of latent variables V in each group of the codebook", + ) + parser.add_argument( + "--latent-groups", + type=int, + help="number of groups G of latent variables in the codebook", + ) + parser.add_argument( + "--latent-dim", + type=int, + help="if > 0, uses this dimensionality for latent variables. " + "otherwise uses final_dim / latent_groups", + ) + parser.add_argument( + "--latent-temp", + type=literal_eval, + help="temperature for latent variable sampling. 
" + "can be tuple of 3 values (start, end, decay)", + ) + parser.add_argument( + "--quantizer-depth", + type=int, + help="number of quantizer layers", + ) + parser.add_argument( + "--quantizer-factor", + type=int, + help="number of quantizer layers", + ) + parser.add_argument( + "--get-code-distribution", + action='store_true', + help="whether to get the code distribution (for test)", + ) + + # relative pos enc + parser.add_argument( + "--relative-position-embedding", + action='store_true', + help="whether to use relative position embedding", + ) + parser.add_argument( + "--num-buckets", + type=int, + default=320, + help="num of buckets for relative position embedding", + ) + parser.add_argument( + "--max-distance", + type=int, + default=1280, + help="max distance for relative position embedding", + ) + parser.add_argument( + "--encoder-max-relative-position", + type=int, + help="max distance for relative position embedding in encoder", + ) + parser.add_argument( + "--decoder-max-relative-position", + type=int, + help="max distance for relative position embedding in decoder", + ) + + # hubert feature extractor + parser.add_argument( + "--conv-feature-layers", + type=str, + help= "string describing convolutional feature extraction " + "layers in form of a python list that contains " + "[(dim, kernel_size, stride), ...]", + ) + parser.add_argument( + "--conv-bias", + action='store_true', + help="include bias in conv encoder", + ) + parser.add_argument( + "--extractor-mode", + choices=["default", "layer_norm"], + help="mode for feature extractor. default has a single group " + "norm with d groups in the first conv block, whereas layer_norm " + "has layer norms in every block (meant to use with normalize=True)" + ) + + # others + parser.add_argument( + "--bert-init", + action='store_true', + help="initilize as bert", + ) + parser.add_argument( + "--unb-enc-layer", + type=int, + default=-1, + help="which layer's output is used as the input of decoder", + ) + + # Encoder, Decoder + @classmethod + def build_encoder(cls, args, dictionary=None, embed_tokens=None): + return TransformerEncoder(args, dictionary, embed_tokens) + + @classmethod + def build_decoder(cls, args): + return TransformerDecoder(args) + + # Encoder Prenet + @classmethod + def build_text_encoder_prenet(cls, embed_tokens, args): + return TextEncoderPrenet(embed_tokens, args) + + @classmethod + def build_speech_encoder_prenet(cls, args): + return SpeechEncoderPrenet(args) + + # Decoder Prenet + @classmethod + def build_text_decoder_prenet(cls, embed_tokens, args): + return TextDecoderPrenet(embed_tokens, args) + + @classmethod + def build_speech_decoder_prenet(cls, odim, args): + return SpeechDecoderPrenet(odim, args) + + # Decoder Postnet + @classmethod + def build_text_decoder_postnet(cls, embed_tokens, dictionary, args): + return TextDecoderPostnet(embed_tokens, dictionary, args) + + @classmethod + def build_speaker_decoder_postnet(cls, embed_dim, class_num, args): + return SpeakerDecoderPostnet(embed_dim, class_num, args) + + @classmethod + def build_speech_decoder_postnet(cls, odim, args): + return SpeechDecoderPostnet(odim, args) + + @classmethod + def build_speech_encoder_postnet(cls, dictionaries, args): + return SpeechEncoderPostnet(dictionaries, args) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present in older models + base_architecture(args) + + def build_embedding(dictionary, embed_dim, max_num_embeddings=None): + num_embeddings = 
len(dictionary) + if max_num_embeddings is not None and isinstance(max_num_embeddings, int): + num_embeddings = min(num_embeddings, max_num_embeddings) + padding_idx = dictionary.pad() + return Embedding(num_embeddings, embed_dim, padding_idx) + + if hasattr(args, "sid_pad_prenet") and args.sid_pad_prenet: + max_num_embeddings = 3 # at index 2 + else: + max_num_embeddings = None + + text_decoder_embed_tokens = build_embedding( + task.dicts["text"], args.decoder_embed_dim, max_num_embeddings + ) + + if args.share_input_output_embed: + text_encoder_embed_tokens = text_decoder_embed_tokens + else: + text_encoder_embed_tokens = build_embedding( + task.dicts["text"], args.encoder_embed_dim + ) + + speech_odim = args.speech_odim + if "text" in task.dicts: + encoder = cls.build_encoder(args, task.dicts["text"], text_encoder_embed_tokens) + else: + encoder = cls.build_encoder(args) + decoder = cls.build_decoder(args) + + text_encoder_prenet = cls.build_text_encoder_prenet(text_encoder_embed_tokens, args) + speech_encoder_prenet = cls.build_speech_encoder_prenet(args) + + text_decoder_prenet = cls.build_text_decoder_prenet(text_decoder_embed_tokens, args) + if getattr(args, "sid_pooling_layer", None) == "decoder-las": + speech_decoder_prenet = cls.build_speech_encoder_prenet(args) + else: + speech_decoder_prenet = cls.build_speech_decoder_prenet(speech_odim, args) + + text_decoder_postnet = cls.build_text_decoder_postnet(text_decoder_embed_tokens, task.dicts['text'], args) + speech_decoder_postnet = cls.build_speech_decoder_postnet(speech_odim, args) + + if getattr(args, "sid_t5_postnet", False): + speaker_decoder_postnet = None + else: + if task.t5_task == "s2c": + speaker_decoder_postnet = cls.build_speaker_decoder_postnet(args.sid_embed_dim, len(task.dicts['text']), args) + else: + speaker_decoder_postnet = None + + if "hubert" in task.dicts: + speech_encoder_postnet = cls.build_speech_encoder_postnet(task.dicts['hubert'], args) + else: + speech_encoder_postnet = None + + return cls( + args, + encoder, decoder, + text_encoder_prenet, speech_encoder_prenet, + text_decoder_prenet, speech_decoder_prenet, + text_decoder_postnet, speech_decoder_postnet, + speaker_decoder_postnet, speech_encoder_postnet, + ) + + def get_normalized_probs( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + # net_output['encoder_out'] is a (B, T, D) tensor + lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample) + lprobs.batch_first = True + return lprobs + + def get_normalized_probs_for_ctc(self, net_output, log_probs): + """Get normalized probabilities (or log probs) from a net's output.""" + + logits = net_output["encoder_out_for_ctc"][0] + if log_probs: + return utils.log_softmax(logits.float(), dim=-1) + else: + return utils.softmax(logits.float(), dim=-1) + + def get_logits(self, net_output, is_masked=True): + if is_masked: + logits_list = net_output["logit_m_list"] + else: + logits_list = net_output["logit_u_list"] + logits_list = [x.float() for x in logits_list if x is not None] + return logits_list + + def get_targets(self, sample, net_output, is_masked=True): + if "logit_m_list" in net_output: + logits_list = self.get_logits(net_output, is_masked) + targets_list = [ + x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list + ] + return targets_list + else: + return sample["target"] + + def get_extra_losses(self, net_output): + extra_losses = [] + names = [] + + if "features_pen" in 
net_output: + extra_losses.append(net_output["features_pen"]) + names.append("features_pen") + + if "prob_perplexity" in net_output: + extra_losses.append( + (net_output["num_vars"] - net_output["prob_perplexity"]) + / net_output["num_vars"] + ) + names.append("prob_perplexity") + + return extra_losses, names + + def forward(self, source=None, src_tokens=None, src_lengths=None, prev_output_tokens=None, tgt_lengths=None, spkembs=None, target_list=None, task_name=None, padding_mask=None, only_hubert=False, only_ctc=False, feature_only=False, tgt_enc_layer=None, mask=True): + """ + The forward method inherited from the base class has a **kwargs + argument in its input, which is not supported in torchscript. This + method overwrites the forward method definition without **kwargs. + """ + assert source is not None or src_tokens is not None + # padding_mask is not none only when input is waveform + if source is None and padding_mask is None and not feature_only: + input_type = 'text' + else: + input_type = 'speech' + + if prev_output_tokens is not None and len(prev_output_tokens.size()) == 2: + output_type = 'text' + codebook_out = {} + else: + output_type = 'speech' + + if task_name is not None and task_name == "s2c": + if target_list is not None and target_list.size(1) == 1 and not getattr(self.args, "sid_t5_postnet", False): + sid_target = F.one_hot(target_list.squeeze(1), num_classes=self.speaker_decoder_postnet.class_num) + else: + sid_target = None + target_list = None + + # Encoder Prenet + if input_type == 'text': + encoder_input, encoder_padding_mask = self.text_encoder_prenet(src_tokens) + else: + if target_list is not None: + encoder_input, encoder_padding_mask = self.speech_encoder_prenet(source, require_feat_pen=True, target_list=target_list, padding_mask=padding_mask, mask=mask) + encoder_input, features_pen, mask_indices, target_list = encoder_input + else: + encoder_input, encoder_padding_mask = self.speech_encoder_prenet(source, padding_mask=padding_mask, mask=self.training) + # shuffle a batch of inputs of encoder + if self.training and hasattr(self.args, "sid_shuffle_encoder_input") and getattr(self.args, "sid_shuffle_encoder_input", False): + shuffle_index = torch.randperm(encoder_padding_mask.size(1), device=encoder_padding_mask.device) + encoder_input = torch.index_select(encoder_input, 1, shuffle_index) + encoder_padding_mask = torch.index_select(encoder_padding_mask, 1, shuffle_index) + if getattr(self.args, "sid_encoder_cls", None) == "encoder": + prev_output_tokens = torch.zeros_like(prev_output_tokens) + encoder_input, encoder_padding_mask = self._integrate_with_speaker_cls(prev_output_tokens, encoder_input, encoder_padding_mask) + + # Encoder: T x B x C + encoder_output = self.encoder(encoder_input, encoder_padding_mask, tgt_layer=tgt_enc_layer) + + if task_name is not None and task_name == 'speech_pretrain' and feature_only: + return encoder_output["encoder_out"][0].transpose(0, 1) + + if task_name is not None and task_name == 's2c': + if self.args.sid_pooling_layer == "encoder": + return self.speaker_decoder_postnet(encoder_output["encoder_out"][0].transpose(0, 1).mean(1), sid_target), None + elif self.args.sid_pooling_layer == "encoder-cls": + return self.speaker_decoder_postnet(encoder_output["encoder_out"][0].transpose(0, 1)[:,0], sid_target), None + elif self.args.sid_pooling_layer == "encoder-speaker" or getattr(self.args, "sid_decoder_speaker", False): + return self.speaker_decoder_postnet(encoder_output["encoder_out"][0].transpose(0, 1), sid_target), None + + 
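# Speech pretraining path (comment added): when HuBERT-style targets are provided, the speech encoder postnet (hubert_layer) scores the masked frames against them. +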
if target_list is not None: + hubert_results = self.hubert_layer( + encoder_output["encoder_out"][0].transpose(0, 1), + encoder_padding_mask, + mask_indices, + target_list + ) + + hubert_results['features_pen'] = features_pen + + if "decoder_input" in encoder_output and encoder_output["decoder_input"][0] is not None: + # Change the encoder output to decoder input once set unb-enc-layer + encoder_output["encoder_out"] = encoder_output["decoder_input"] + + if self.use_codebook: + q = self.quantizer(encoder_output["encoder_out"][0].transpose(0, 1)) + + # q["x"]: B x T x C + # Sample indexs according to the codebook prob + random_idx = torch.randperm(q["x"].size(1))[:int(q["x"].size(1) * self.codebook_prob)] + # Make weight for q + q_w = q["x"].new_zeros(q["x"].size(1)) + q_w[random_idx] = 1.0 + # Combine quantized codes and encoder output + encoder_output["encoder_out"][0] = ( + q_w.view(-1, 1) * q["x"] + (- q_w + 1).view(-1, 1) * encoder_output["encoder_out"][0].transpose(0, 1) + ).transpose(0, 1) + + # encoder_output["encoder_out"][0] = q["x"].transpose(0, 1) + if output_type == 'speech': + hubert_results["prob_perplexity"] = q["prob_perplexity"] + hubert_results["code_perplexity"] = q["code_perplexity"] + hubert_results["num_vars"] = q["num_vars"] + hubert_results["temp"] = q["temp"] + elif output_type == 'text': + codebook_out["prob_perplexity"] = q["prob_perplexity"] + codebook_out["code_perplexity"] = q["code_perplexity"] + codebook_out["num_vars"] = q["num_vars"] + codebook_out["temp"] = q["temp"] + + if only_hubert and target_list is not None: + return hubert_results, None + + if only_ctc and task_name is not None and task_name == "s2t": + return None, encoder_output + elif not self.training and prev_output_tokens is None and task_name == "s2t" and task_name is not None: + return encoder_output + + # Decoder Prenet + if output_type == 'text': + # _ is the incremental state + prev_output_tokens, tgt_mask, _ = self.text_decoder_prenet(prev_output_tokens) + if task_name is not None and task_name == 's2c': + prev_output_tokens = torch.zeros_like(prev_output_tokens) + else: + # integrate speaker embedding + if self.spk_embed_integration_type == "pre" and self.spk_embed_dim is not None: + # Decoder Prenet + prev_output_tokens, tgt_mask = self.speech_decoder_prenet(prev_output_tokens, tgt_lengths, spkembs) + else: + if self.spk_embed_dim is not None: + encoder_output["encoder_out"] = [self._integrate_with_spk_embed( + encoder_output["encoder_out"][0].transpose(0, 1), spkembs + ).transpose(0, 1)] + + prev_output_tokens, tgt_mask = self.speech_decoder_prenet(prev_output_tokens, tgt_lengths) + + # BART Sequence Classification: cat + feature before decoder + if task_name is not None and task_name == 's2c' and self.args.sid_pooling_layer == "decoder-las": + decoder_feat_input, decoder_feat_mask = self.speech_decoder_prenet(src_tokens, src_lengths) + prev_output_tokens, tgt_mask = self._integrate_with_speaker_cls((prev_output_tokens, tgt_mask), decoder_feat_input, decoder_feat_mask, cls_first=False) + + # SE predict masking to corresponding inputs and source speech replaces the prev_output_tokens as the input of decoder + if task_name is not None and task_name == "s2s" and getattr(self.args, "se_decoder_input", "previous_target") == "source": + prev_output_tokens, tgt_mask = self.speech_decoder_prenet(src_tokens, src_lengths) + + # Decoder + decoder_output, extra = self.decoder(prev_output_tokens, tgt_mask, encoder_output, + full_context_alignment=getattr(self.args, 
"decoder_full_context_alignment", False), + alignment_layer=(-1 if target_list is None and output_type == 'speech' else None)) + # Decoder Postnet + if task_name is not None and task_name == 's2c': + if not getattr(self.args, "sid_t5_postnet", False): + if self.args.sid_pooling_layer == "decoder": + return self.speaker_decoder_postnet(decoder_output.mean(1), sid_target), None + elif self.args.sid_pooling_layer == "decoder-las": + indices = (tgt_mask.eq(False).float().sum(1) - 1.0).type(torch.int64) + indices = indices.unsqueeze(1).unsqueeze(2).expand(-1, -1, decoder_output.size(2)) + return self.speaker_decoder_postnet(decoder_output.gather(1, indices), sid_target), None + else: + return (self.text_decoder_postnet(decoder_output), None), encoder_output + + # SE predict: masking, target, delta. Ensure reduction factor 1 + if task_name is not None and task_name == 's2s' and getattr(self.args, "se_predict", None) is not None: + assert self.reduction_factor == 1, f"{self.reduction_factor} != 1" + before_outs, after_outs, logits = self.speech_decoder_postnet(decoder_output) + se_predict = getattr(self.args, "se_predict") + if se_predict == "masking": + before_outs = torch.sigmoid(before_outs) * src_tokens + after_outs = torch.sigmoid(after_outs) * src_tokens + return before_outs, after_outs, logits, extra['attn'][0] + elif se_predict == "target": + return before_outs, after_outs, logits, extra['attn'][0] + elif se_predict == "delta": + before_outs = before_outs - src_tokens + after_outs = after_outs - src_tokens + return before_outs, after_outs, logits, extra['attn'][0] + else: + raise ValueError(f"{se_predict} not in [masking, target, delta]") + + if task_name is not None and task_name == 's2t': + #return self.text_decoder_postnet(decoder_output), None + return (self.text_decoder_postnet(decoder_output), None), encoder_output + if output_type == 'text': + return (self.text_decoder_postnet(decoder_output), None), codebook_out, encoder_output + else: + if target_list is not None: + return hubert_results, (self.speech_decoder_postnet(decoder_output) + (extra['attn'][0],)) + else: + return self.speech_decoder_postnet(decoder_output) + (extra['attn'][0],) + + def _integrate_with_speaker_cls(self, pad_input, encoder_input, encoder_padding_mask=None, cls_first=True): + """ + encoder_input: [B, T, C] + encoder_padding_mask: [B, T] + """ + if hasattr(self, "text_decoder_prenet"): + if isinstance(pad_input, tuple): + repeat_cls_vector, repeat_cls_mask = pad_input + else: + repeat_cls_vector, repeat_cls_mask, _ = self.text_decoder_prenet(pad_input) + + if encoder_padding_mask is not None: + bsz = encoder_input.size(0) + tsz = encoder_input.size(1) + encoder_padding_mask = encoder_input.new_zeros((bsz, tsz)) == 1.0 + if repeat_cls_mask is None: + mask_size = (encoder_padding_mask.size(0), 1) + mask_type = encoder_padding_mask.dtype + repeat_cls_mask = encoder_padding_mask.new_zeros(mask_size) == 1.0 + ret_encoder_padding_mask = torch.cat([repeat_cls_mask, encoder_padding_mask], dim=1) + + if cls_first: + ret_encoder_input = torch.cat([repeat_cls_vector, encoder_input], dim=1) + else: + ret_encoder_input = torch.cat([encoder_input, encoder_input[:,-1:,:]], dim=1) + mask_size = (encoder_padding_mask.size(0), 1) + mask_type = encoder_padding_mask.dtype + repeat_cls_mask_ = encoder_padding_mask.new_ones(mask_size) == 1.0 + encoder_padding_mask_ = torch.cat([encoder_padding_mask, repeat_cls_mask_], dim=1) + indices = encoder_padding_mask.eq(False).float().sum(1).type(torch.int64).unsqueeze(1) + indices_mask = 
torch.zeros_like(ret_encoder_padding_mask).scatter(1, indices, 1.0) + ret_encoder_input = ret_encoder_input * (1.0 - encoder_padding_mask_.type(ret_encoder_input.dtype).unsqueeze(2)) \ + + repeat_cls_vector * indices_mask.type(repeat_cls_vector.dtype).unsqueeze(2) + + return ret_encoder_input, ret_encoder_padding_mask + + def _integrate_with_spk_embed(self, hs, spembs): + """Integrate speaker embedding with hidden states. + Args: + hs (Tensor): Batch of hidden state sequences (B, Tmax, adim). + spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim). + Returns: + Tensor: Batch of integrated hidden state sequences (B, Tmax, adim) + """ + if self.spk_embed_integration_type == "add": + # apply projection and then add to hidden states + spembs = self.projection(F.normalize(spembs)) + hs = hs + spembs.unsqueeze(1) + elif self.spk_embed_integration_type == "concat": + # concat hidden states with spk embeds and then apply projection + spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1) + hs = self.projection(torch.cat([hs, spembs], dim=-1)) + else: + raise NotImplementedError("support only add or concat.") + + return hs + + def load_state_dict( + self, + state_dict, + strict=True, + model_cfg=None, + args=None, + ): + """NOT STRICT Copies parameters and buffers from *state_dict* into this module and + its descendants. + + Overrides the method in :class:`nn.Module`. Compared with that method + this additionally "upgrades" *state_dicts* from old checkpoints. + """ + # self.prune_modules(model_cfg.modules_filter) + model_dict_size = self.text_decoder_postnet.output_projection.out_features + ckpt_dict_size = state_dict["text_decoder_postnet.output_projection.weight"].size(0) + if model_dict_size != ckpt_dict_size: + # reset dictionary-related modules, such as embedding table and encoder ctc embed + logger.warn(f"not equal dictionary between model and checkpoint: {model_dict_size} vs {ckpt_dict_size}") + logger.info(f"reset model dictionary with size of {model_dict_size}") + removed_keys = [ + key for key in state_dict.keys() if any( + key.startswith(previ) for previ in [ + "encoder.proj", "text_encoder_prenet", "text_decoder_prenet", "text_decoder_postnet" + ] + ) + ] + for key in removed_keys: + state_dict.pop(key, None) + logger.info(f"removed loaded checkpoint: {key}") + for m in self._modules.keys(): + m_state_dict = { + key.replace(f"{m}.", ""): value for key, value in state_dict.items() if key.startswith(f"{m}.") + } + if hasattr(self, m): + self._modules[m].load_state_dict(m_state_dict, False) + return self + + def prune_modules(self, modules_filter=None): + """Prune unused modules for specific tasks.""" + if modules_filter is None: + return + elif modules_filter == "s2c": + if hasattr(self, "text_encoder_prenet"): del self.text_encoder_prenet + if hasattr(self, "speech_decoder_prenet") and getattr(self.args, "sid_pooling_layer", None) != "decoder-las": + del self.speech_decoder_prenet + if hasattr(self, "speech_decoder_postnet"): del self.speech_decoder_postnet + if hasattr(self, "text_decoder_postnet"): del self.text_decoder_postnet + if hasattr(self, "speech_encoder_postnet"): del self.speech_encoder_postnet + if hasattr(self.encoder, "proj"): self.encoder.proj = None + if hasattr(self, "projection"): del self.projection + if hasattr(self, "quantizer"): del self.quantizer + if getattr(self.args, "sid_pooling_layer", "decoder").startswith("encoder") or getattr(self.args, "sid_decoder_speaker", False): + if hasattr(self.decoder, "dropout_module"): del 
self.decoder.dropout_module + if hasattr(self.decoder, "layers"): del self.decoder.layers + if hasattr(self.decoder, "layer_norm"): del self.decoder.layer_norm + if hasattr(self, "text_decoder_prenet"): del self.text_decoder_prenet + elif modules_filter == "s2s": + if hasattr(self, "speaker_decoder_postnet"): del self.speaker_decoder_postnet + if hasattr(self, "text_encoder_prenet"): del self.text_encoder_prenet + if hasattr(self, "text_decoder_prenet"): del self.text_decoder_prenet + if hasattr(self, "text_decoder_postnet"): del self.text_decoder_postnet + if hasattr(self, "speech_encoder_postnet"): del self.speech_encoder_postnet + if hasattr(self.encoder, "proj"): self.encoder.proj = None + if hasattr(self, "projection"): del self.projection + if hasattr(self, "quantizer"): del self.quantizer + elif modules_filter == "t2s": + if hasattr(self, "speaker_decoder_postnet"): del self.speaker_decoder_postnet + if hasattr(self, "speech_encoder_prenet"): del self.speech_encoder_prenet + if hasattr(self, "text_decoder_prenet"): del self.text_decoder_prenet + if hasattr(self, "text_decoder_postnet"): del self.text_decoder_postnet + if hasattr(self, "speech_encoder_postnet"): del self.speech_encoder_postnet + if hasattr(self.encoder, "proj"): self.encoder.proj = None + if hasattr(self, "projection"): del self.projection + if hasattr(self, "quantizer"): del self.quantizer + elif modules_filter == "s3prl": + # remain the encoder and the pre/post net + if hasattr(self.decoder, "dropout_module"): del self.decoder.dropout_module + if hasattr(self.decoder, "layers"): del self.decoder.layers + if hasattr(self.decoder, "layer_norm"): del self.decoder.layer_norm + if hasattr(self, "speaker_decoder_postnet"): del self.speaker_decoder_postnet + if hasattr(self, "text_decoder_prenet"): del self.text_decoder_prenet + if hasattr(self, "text_decoder_postnet"): del self.text_decoder_postnet + if hasattr(self, "speech_decoder_prenet"): del self.speech_decoder_prenet + if hasattr(self, "speech_decoder_postnet"): del self.speech_decoder_postnet + if hasattr(self, "speech_encoder_postnet"): del self.speech_encoder_postnet + if hasattr(self.encoder, "proj"): self.encoder.proj = None + if hasattr(self, "projection"): del self.projection + if hasattr(self, "quantizer"): del self.quantizer + + def forward_encoder_torchscript(self, net_input: Dict[str, Tensor]): + """A TorchScript-compatible version of forward. + + Encoders which use additional arguments may want to override + this method for TorchScript compatibility. 
+        """
+        if torch.jit.is_scripting():
+            return self.forward_encoder(
+                source=net_input["source"],
+                padding_mask=net_input["padding_mask"]
+            )
+        else:
+            return self.forward_encoder_non_torchscript(net_input)
+
+    @torch.jit.unused
+    def forward_encoder_non_torchscript(self, net_input: Dict[str, Tensor]):
+        encoder_input = {
+            k: v for k, v in net_input.items() if k != "prev_output_tokens" and k != "task_name"
+        }
+        return self.forward_encoder(**encoder_input)
+
+    def forward_encoder(self, source, padding_mask=None):
+        # Encoder Prenet
+        encoder_input, encoder_padding_mask = self.speech_encoder_prenet(source, padding_mask=padding_mask, mask=False)
+
+        # Encoder
+        encoder_output = self.encoder(encoder_input, encoder_padding_mask)
+
+        return encoder_output
+
+    def forward_text_encoder(self, src_tokens):
+        # Text Encoder Prenet
+        encoder_input, encoder_padding_mask = self.text_encoder_prenet(src_tokens)
+
+        # Encoder
+        encoder_output = self.encoder(encoder_input, encoder_padding_mask)
+
+        return encoder_output
+
+    def forward_decoder(self, tokens, encoder_out, incremental_state):
+        # Decoder Prenet
+        prev_output_tokens, tgt_mask, incremental_state = self.text_decoder_prenet(tokens, incremental_state)
+
+        # Decoder
+        decoder_output, extra = self.decoder(
+            prev_output_tokens,
+            tgt_mask,
+            encoder_out=encoder_out,
+            incremental_state=incremental_state,
+        )
+
+        # Decoder Postnet
+        return self.text_decoder_postnet(decoder_output), extra
+
+    def set_num_updates(self, num_updates):
+        """Set the number of parameter updates."""
+        super().set_num_updates(num_updates)
+        self.num_updates = num_updates
+
+    def generate_class(self, source, prev_output_tokens, **kwargs):
+        encoder_out = self.forward_encoder(source, padding_mask=kwargs["padding_mask"])
+
+        prev_output_tokens, tgt_mask, _ = self.text_decoder_prenet(prev_output_tokens, {})
+        prev_output_tokens = torch.zeros_like(prev_output_tokens)  # s2c uses a zero vector as [CLS]
+
+        decoder_output, extra = self.decoder(
+            prev_output_tokens,
+            tgt_mask,
+            encoder_out=encoder_out,
+        )
+
+        decoder_out, embed = self.speaker_decoder_postnet(decoder_output.mean(1))
+
+        pred_class = decoder_out.argmax(1)
+        return pred_class
+
+    def generate_speech(self, source=None, src_tokens=None, spkembs=None, **kwargs):
+        assert source is not None or src_tokens is not None
+
+        threshold = kwargs.get("threshold", 0.5)
+        minlenratio = kwargs.get("minlenratio", 0.0)
+
+        if source is None:
+            assert src_tokens.size(0) == 1
+            encoder_out = self.forward_text_encoder(src_tokens)
+            maxlenratio = kwargs.get("maxlenratio", 20.0)
+        else:
+            assert source.size(0) == 1
+            encoder_out = self.forward_encoder(source, padding_mask=kwargs["padding_mask"])
+            maxlenratio = kwargs.get("maxlenratio", 10.0)
+
+        if spkembs is not None and self.spk_embed_integration_type != "pre":
+            encoder_out["encoder_out"] = [self._integrate_with_spk_embed(
+                encoder_out["encoder_out"][0].transpose(0, 1), spkembs
+            ).transpose(0, 1)]
+            spkembs = None
+
+        maxlen = int(encoder_out["encoder_out"][0].size(0) * maxlenratio / self.reduction_factor)
+        minlen = int(encoder_out["encoder_out"][0].size(0) * minlenratio / self.reduction_factor)
+
+        idx = 0
+        ys = encoder_out["encoder_out"][0].new_zeros(1, 1, self.speech_decoder_postnet.odim)
+        outs, probs = [], []
+
+        # forward decoder step-by-step
+        if isinstance(self.decoder, FairseqIncrementalDecoder):
+            incremental_states = {}
+        else:
+            incremental_states = None
+        attns = []
+        while True:
+            # update index
+            idx += 1
+            # calculate output and stop prob at idx-th step
+            decoder_in, _ = 
self.speech_decoder_prenet(ys, spkembs=spkembs) + z, extra = self.decoder(decoder_in[:,-1:], None, encoder_out, incremental_states, alignment_layer=-1) + outs += [self.speech_decoder_postnet.feat_out(z[0, -1]).view(self.reduction_factor, self.speech_decoder_postnet.odim)] # [(r, odim), ...] + probs += [torch.sigmoid(self.speech_decoder_postnet.prob_out(z[0, -1]))] # [(r), ...] + + # update next inputs + ys = torch.cat((ys, outs[-1][-1].view(1, 1, self.speech_decoder_postnet.odim)), dim=1) # (1, idx + 1, odim) + attns.append(torch.stack([att_l[0] for att_l in extra['attn'][0]], dim=0)) + # check whether to finish generation + if int(sum(probs[-1] >= threshold)) > 0 or idx >= maxlen: + # check mininum length + if idx < minlen: + continue + outs = (torch.cat(outs, dim=0).unsqueeze(0).transpose(1, 2)) # (L, odim) -> (1, L, odim) -> (1, odim, L) + if self.speech_decoder_postnet.postnet is not None: + outs = outs + self.speech_decoder_postnet.postnet(outs) # (1, odim, L) + outs = outs.transpose(2, 1).squeeze(0) # (L, odim) + probs = torch.cat(probs, dim=0) + attn = torch.cat(attns, dim=2) + break + + if outs.size(0) == maxlen: + logging.warning("output length reaches maximum length") + return outs, probs, attn + + +@register_model_architecture(model_name="artst_transformer", arch_name="artst_transformer") +def base_architecture(args): + # Transformer + args.bert_init = getattr(args, "bert_init", False) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 768 * 4) + args.encoder_layers = getattr(args, "encoder_layers", 12) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 12) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", args.dropout) + args.activation_dropout = getattr(args, "activation_dropout", args.dropout) + args.activation_fn = getattr(args, "activation_fn", "gelu") + args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0) + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0) + args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0) + args.max_text_positions = getattr(args, "max_text_positions", DEFAULT_MAX_TEXT_POSITIONS) + args.max_speech_positions = getattr(args, "max_speech_positions", DEFAULT_MAX_SPEECH_POSITIONS) + + # Espnet related, including prenet, postnet + args.eprenet_conv_layers = getattr(args, "eprenet_conv_layers", 0) + args.eprenet_conv_filts = getattr(args, "eprenet_conv_filts", 0) + args.eprenet_conv_chans = getattr(args, "eprenet_conv_chans", 0) + args.use_batch_norm = getattr(args, "use_batch_norm", True) + args.eprenet_dropout_rate = getattr(args, "eprenet_dropout_rate", 0.0) + args.enc_use_scaled_pos_enc = getattr(args, "enc_use_scaled_pos_enc", True) + args.dec_use_scaled_pos_enc = getattr(args, 
"dec_use_scaled_pos_enc", True) + args.postnet_layers = getattr(args, "postnet_layers", 5) + args.postnet_chans = getattr(args, "postnet_chans", 256) + args.postnet_filts = getattr(args, "postnet_filts", 5) + args.postnet_dropout_rate = getattr(args, "postnet_dropout_rate", 0.5) + args.dprenet_dropout_rate = getattr(args, "dprenet_dropout_rate", 0.5) + args.dprenet_layers = getattr(args, "dprenet_layers", 2) + args.dprenet_units = getattr(args, "dprenet_units", 256) + args.initial_encoder_alpha = getattr(args, "initial_encoder_alpha", 1.0) + args.initial_decoder_alpha = getattr(args, "initial_decoder_alpha", 1.0) + args.spk_embed_integration_type = getattr(args, "spk_embed_integration_type", "pre") + args.spk_embed_dim = getattr(args, "spk_embed_dim", 512) + args.encoder_reduction_factor = getattr(args, "encoder_reduction_factor", 1) + args.reduction_factor = getattr(args, "reduction_factor", 2) + args.transformer_enc_positional_dropout_rate = getattr(args, "transformer_enc_positional_dropout_rate", 0.1) + args.transformer_dec_positional_dropout_rate = getattr(args, "transformer_dec_positional_dropout_rate", 0.1) + args.layer_norm_eps = getattr(args, "layer_norm_eps", 1e-5) + args.no_scale_embedding = getattr(args, "no_scale_embedding", True) + # Convolutional subsampler + args.encoder_speech_prenet = getattr(args, "encoder_speech_prenet", "conv") + args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5") + args.conv_channels = getattr(args, "conv_channels", 1024) + args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) + + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.share_input_output_embed = getattr(args, "share_input_output_embed", False) + args.share_ctc_embed = getattr(args, "share_ctc_embed", False) + args.freeze_encoder_updates = getattr(args, "freeze_encoder_updates", 0) + args.freeze_decoder_updates = getattr(args, "freeze_decoder_updates", 0) + args.no_freeze_encoder_layer = getattr(args, "no_freeze_encoder_layer", None) + + ## sid + args.sid_embed_dim = getattr(args, "sid_embed_dim", 128) + args.sid_pooling_layer = getattr(args, "sid_pooling_layer", "decoder") + args.softmax_scale = getattr(args, "softmax_scale", 1) + args.softmax_margin = getattr(args, "softmax_margin", 0) + args.softmax_easy_margin = getattr(args, "softmax_easy_margin", False) + args.modules_filter = getattr(args, "modules_filter", None) + + ## Hubert + args.conv_pos = getattr(args, "conv_pos", 128) + args.conv_pos_groups = getattr(args, "conv_pos_groups", 16) + args.target_glu = getattr(args, "target_glu", False) + args.logit_temp = getattr(args, "logit_temp", 0.1) + args.final_dim = getattr(args, "final_dim", 256) + args.untie_final_proj = getattr(args, "untie_final_proj", True) + args.feature_grad_mult = getattr(args, "feature_grad_mult", 0.1) + args.use_sent_enc_layer = getattr(args, "use_sent_enc_layer", True) + # hubert feature extractor + args.extractor_mode = getattr(args, "extractor_mode", "default") + args.conv_feature_layers = getattr(args, "conv_feature_layers", "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2") + args.conv_bias = getattr(args, "conv_bias", False) + # mask + args.hubert_mask_length = getattr(args, "hubert_mask_length", 
10) + args.mask_prob = getattr(args, "mask_prob", 0.0) + args.mask_selection = getattr(args, "mask_selection", "static") + args.mask_other = getattr(args, "mask_other", 0) + args.no_mask_overlap = getattr(args, "no_mask_overlap", False) + args.mask_min_space = getattr(args, "mask_min_space", 1) + # channel mask + args.mask_channel_length = getattr(args, "mask_channel_length", 10) + args.mask_channel_prob = getattr(args, "mask_channel_prob", 0.0) + args.mask_channel_selection = getattr(args, "mask_channel_selection", "static") + args.mask_channel_other = getattr(args, "mask_channel_other", 0) + args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap", False) + args.mask_channel_min_space = getattr(args, "mask_channel_min_space", 1) + # loss computation + args.skip_masked = getattr(args, "skip_masked", False) + args.skip_nomask = getattr(args, "skip_nomask", False) + # conv Pos + args.use_conv_pos = getattr(args, "use_conv_pos", False) + args.use_sinc_pos = getattr(args, "use_sinc_pos", False) + + # codebook + args.use_codebook = getattr(args, "use_codebook", False) + args.latent_vars = getattr(args, "latent_vars", 100) + args.latent_groups = getattr(args, "latent_groups", 2) + args.latent_dim = getattr(args, "latent_dim", 0) + args.latent_temp = getattr(args, "latent_temp", (2, 0.5, 0.999995)) + args.quantizer_depth = getattr(args, "quantizer_depth", 1) + args.quantizer_factor = getattr(args, "quantizer_factor", 3) + args.codebook_prob = getattr(args, "codebook_prob", 0.5) + + # Relative pos embed + args.relative_position_embedding = getattr(args, "relative_position_embedding", False) + args.num_buckets = getattr(args, "num_buckets", 320) + args.max_distance = getattr(args, "max_distance", 1280) + args.encoder_max_relative_position = getattr(args, "encoder_max_relative_position", 160) + args.decoder_max_relative_position = getattr(args, "decoder_max_relative_position", 160) + +@register_model_architecture("artst_transformer", "artst_transformer_base") +def artst_transformer_base(args): + args.use_conv_pos = getattr(args, "use_conv_pos", True) + args.use_sinc_pos = getattr(args, "use_sinc_pos", True) + args.layernorm_embedding = getattr(args, "layernorm_embedding", False) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) + args.layer_norm_first = getattr(args, "layer_norm_first", False) + args.relative_position_embedding = getattr(args, "relative_position_embedding", True) + args.dropout = getattr(args, "dropout", 0.1) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.05) + args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.05) + args.mask_prob = getattr(args, "mask_prob", 0.80) + base_architecture(args) + +@register_model_architecture("artst_transformer", "artst_transformer_large") +def artst_transformer_large(args): + args.use_conv_pos = getattr(args, "use_conv_pos", True) + args.use_sinc_pos = getattr(args, "use_sinc_pos", True) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True) + args.layer_norm_first = getattr(args, "layer_norm_first", True) + args.relative_position_embedding = getattr(args, "relative_position_embedding", True) + args.dropout = getattr(args, "dropout", 0.0) + 
args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) + args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + args.encoder_layers = getattr(args, "encoder_layers", 24) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) + args.feature_grad_mult = getattr(args, "feature_grad_mult", 1.0) + args.extractor_mode = getattr(args, "extractor_mode", "layer_norm") + args.final_dim = getattr(args, "final_dim", 768) + args.mask_prob = getattr(args, "mask_prob", 0.80) + base_architecture(args) + +@register_model_architecture("artst_transformer", "artst_transformer_base_asr") +def artst_transformer_base_asr(args): + args.use_conv_pos = getattr(args, "use_conv_pos", True) + args.use_sinc_pos = getattr(args, "use_sinc_pos", True) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) + args.layer_norm_first = getattr(args, "layer_norm_first", False) + args.relative_position_embedding = getattr(args, "relative_position_embedding", True) + args.dropout = getattr(args, "dropout", 0.1) + args.activation_dropout = getattr(args, "activation_dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.feature_grad_mult = getattr(args, "feature_grad_mult", 0.0) + args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.1) + args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.1) + args.mask_prob = getattr(args, "mask_prob", 0.75) + args.mask_selection = getattr(args, "mask_selection", "static") + args.mask_channel_length = getattr(args, "mask_channel_length", 64) + args.mask_channel_prob = getattr(args, "mask_channel_prob", 0.5) + args.mask_channel_selection = getattr(args, "mask_channel_selection", "static") + args.max_text_positions = getattr(args, "max_text_positions", 600) + base_architecture(args) diff --git a/artst/models/modules/__init__.py b/artst/models/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/artst/models/modules/__pycache__/__init__.cpython-38.pyc b/artst/models/modules/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbbd8dd8f895833abda75ac75748818863e8dac9 Binary files /dev/null and b/artst/models/modules/__pycache__/__init__.cpython-38.pyc differ diff --git a/artst/models/modules/__pycache__/decoder.cpython-38.pyc b/artst/models/modules/__pycache__/decoder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4b30cc3d861471204538738e51efcd7352f28ed Binary files /dev/null and b/artst/models/modules/__pycache__/decoder.cpython-38.pyc differ diff --git a/artst/models/modules/__pycache__/encoder.cpython-38.pyc b/artst/models/modules/__pycache__/encoder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a9858130d42e5af680ebfee04abad78d32cd4c7 Binary files /dev/null and b/artst/models/modules/__pycache__/encoder.cpython-38.pyc differ diff --git 
a/artst/models/modules/__pycache__/multihead_attention.cpython-38.pyc b/artst/models/modules/__pycache__/multihead_attention.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3c86ab8ae26e42d6035cc436c7a07ad41b608fb Binary files /dev/null and b/artst/models/modules/__pycache__/multihead_attention.cpython-38.pyc differ diff --git a/artst/models/modules/__pycache__/speaker_decoder_postnet.cpython-38.pyc b/artst/models/modules/__pycache__/speaker_decoder_postnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d57379e07f40d55e342567bbc47003816f52eef4 Binary files /dev/null and b/artst/models/modules/__pycache__/speaker_decoder_postnet.cpython-38.pyc differ diff --git a/artst/models/modules/__pycache__/speech_decoder_postnet.cpython-38.pyc b/artst/models/modules/__pycache__/speech_decoder_postnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de6c9262fa356371511b26561608caa02cb9fb78 Binary files /dev/null and b/artst/models/modules/__pycache__/speech_decoder_postnet.cpython-38.pyc differ diff --git a/artst/models/modules/__pycache__/speech_decoder_prenet.cpython-38.pyc b/artst/models/modules/__pycache__/speech_decoder_prenet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..444e0be4e2d64eed654e9d345a042a0e12171760 Binary files /dev/null and b/artst/models/modules/__pycache__/speech_decoder_prenet.cpython-38.pyc differ diff --git a/artst/models/modules/__pycache__/speech_encoder_postnet.cpython-38.pyc b/artst/models/modules/__pycache__/speech_encoder_postnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31234b6df22301bf83ba1ece39690f5314f4eac7 Binary files /dev/null and b/artst/models/modules/__pycache__/speech_encoder_postnet.cpython-38.pyc differ diff --git a/artst/models/modules/__pycache__/speech_encoder_prenet.cpython-38.pyc b/artst/models/modules/__pycache__/speech_encoder_prenet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a80d1b8acbeb5db0f3ac5722a5e1826dc5e9d2d Binary files /dev/null and b/artst/models/modules/__pycache__/speech_encoder_prenet.cpython-38.pyc differ diff --git a/artst/models/modules/__pycache__/text_decoder_postnet.cpython-38.pyc b/artst/models/modules/__pycache__/text_decoder_postnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5db269e11ffb0dcbac7db1f0de6ee79fa31b2db8 Binary files /dev/null and b/artst/models/modules/__pycache__/text_decoder_postnet.cpython-38.pyc differ diff --git a/artst/models/modules/__pycache__/text_decoder_prenet.cpython-38.pyc b/artst/models/modules/__pycache__/text_decoder_prenet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff5e6d3b502124f5e5326d8bff1f46e30daccd4e Binary files /dev/null and b/artst/models/modules/__pycache__/text_decoder_prenet.cpython-38.pyc differ diff --git a/artst/models/modules/__pycache__/text_encoder_prenet.cpython-38.pyc b/artst/models/modules/__pycache__/text_encoder_prenet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5871aad006fd1d2d41cdda1518f2c5fff534665b Binary files /dev/null and b/artst/models/modules/__pycache__/text_encoder_prenet.cpython-38.pyc differ diff --git a/artst/models/modules/__pycache__/transformer_layer.cpython-38.pyc b/artst/models/modules/__pycache__/transformer_layer.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..3b0d3dcebcf51e3b935012a74c1b1454a1afd39a Binary files /dev/null and b/artst/models/modules/__pycache__/transformer_layer.cpython-38.pyc differ diff --git a/artst/models/modules/decoder.py b/artst/models/modules/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c759ed2a2e6743bc9df771fff8b651630ab92319 --- /dev/null +++ b/artst/models/modules/decoder.py @@ -0,0 +1,323 @@ +# -------------------------------------------------------- +# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621) +# Github source: https://github.com/mbzuai-nlp/ArTST + +# Based on speecht5, fairseq and espnet code bases +# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet +# -------------------------------------------------------- + +from typing import Any, Dict, List, Optional + +import torch +import torch.nn as nn +from fairseq import utils +from fairseq.distributed import fsdp_wrap +from fairseq.models import ( + FairseqIncrementalDecoder, +) +from fairseq.modules import ( + FairseqDropout, + LayerDropModuleList, + LayerNorm, +) +from fairseq.modules.checkpoint_activations import checkpoint_wrapper +from torch import Tensor + +from .encoder import RelativePositionalEncoding +from .transformer_layer import TransformerDecoderLayer + +DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8) + + +class TransformerDecoder(FairseqIncrementalDecoder): + """ + Transformer decoder consisting of *args.decoder_layers* layers. Each layer + is a :class:`TransformerDecoderLayer`. + + Args: + args (argparse.Namespace): parsed command-line arguments + dictionary (~fairseq.data.Dictionary): decoding dictionary + embed_tokens (torch.nn.Embedding): output embedding + no_encoder_attn (bool, optional): whether to attend to encoder outputs + (default: False). 
+ """ + + def __init__( + self, + args, + no_encoder_attn=False, + ): + self.args = args + super().__init__(None) + self.register_buffer("version", torch.Tensor([3])) + self._future_mask = torch.empty(0) + + self.dropout_module = FairseqDropout( + args.dropout, module_name=self.__class__.__name__ + ) + self.decoder_layerdrop = args.decoder_layerdrop + # self.max_s_positions = args.max_target_positions + export = getattr(args, "export", False) + self.cross_self_attention = getattr(args, "cross_self_attention", False) + + if self.decoder_layerdrop > 0.0: + self.layers = LayerDropModuleList(p=self.decoder_layerdrop) + else: + self.layers = nn.ModuleList([]) + self.layers.extend( + [ + self.build_decoder_layer(args, no_encoder_attn) + for _ in range(args.decoder_layers) + ] + ) + self.num_layers = len(self.layers) + + if args.decoder_normalize_before and not getattr( + args, "no_decoder_final_norm", False + ): + self.layer_norm = LayerNorm(args.decoder_embed_dim, eps=args.layer_norm_eps, export=export) + else: + self.layer_norm = None + + if args.relative_position_embedding: + self.pos_emb = RelativePositionalEncoding(args.encoder_embed_dim//args.encoder_attention_heads, args.decoder_max_relative_position) + + def build_decoder_layer(self, args, no_encoder_attn=False): + layer = TransformerDecoderLayer(args, no_encoder_attn=no_encoder_attn, has_relative_attention_bias=args.relative_position_embedding) + checkpoint = getattr(args, "checkpoint_activations", False) + if checkpoint: + offload_to_cpu = getattr(args, "offload_activations", False) + layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) + # if we are checkpointing, enforce that FSDP always wraps the + # checkpointed layer, regardless of layer size + min_params_to_wrap = ( + getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP) + if not checkpoint + else 0 + ) + layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) + return layer + + def forward( + self, + prev_output_tokens, + tgt_mask, + encoder_out: Optional[Dict[str, List[Tensor]]] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + src_lengths: Optional[Any] = None, + return_all_hiddens: bool = False, + ): + """ + Args: + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing + encoder_out (optional): output from the encoder, used for + encoder-side attention, should be of size T x B x C + incremental_state (dict): dictionary used for storing state during + :ref:`Incremental decoding` + features_only (bool, optional): only return features without + applying output layer (default: False). + full_context_alignment (bool, optional): don't apply + auto-regressive mask to self-attention (default: False). 
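+            tgt_mask (Tensor): padding mask over the already-embedded decoder
+                inputs of shape `(batch, tgt_len)`, forwarded to each decoder
+                layer as the self-attention padding mask; may be ``None``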
+ + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + + x, extra = self.extract_features( + prev_output_tokens, + tgt_mask, + encoder_out=encoder_out, + incremental_state=incremental_state, + full_context_alignment=full_context_alignment, + alignment_layer=alignment_layer, + alignment_heads=alignment_heads, + ) + + return x, extra + + def extract_features( + self, + prev_output_tokens, + tgt_mask, + encoder_out: Optional[Dict[str, List[Tensor]]], + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + ): + return self.extract_features_scriptable( + prev_output_tokens, + tgt_mask, + encoder_out, + incremental_state, + full_context_alignment, + alignment_layer, + alignment_heads, + ) + + """ + A scriptable subclass of this class has an extract_features method and calls + super().extract_features, but super() is not supported in torchscript. A copy of + this function is made to be used in the subclass instead. + """ + + def extract_features_scriptable( + self, + prev_output_tokens, + tgt_mask, + encoder_out: Optional[Dict[str, List[Tensor]]], + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + ): + """ + Similar to *forward* but only return features. + + Includes several features from "Jointly Learning to Align and + Translate with Transformer Models" (Garg et al., EMNLP 2019). + + Args: + full_context_alignment (bool, optional): don't apply + auto-regressive mask to self-attention (default: False). + alignment_layer (int, optional): return mean alignment over + heads at this layer (default: last layer). + alignment_heads (int, optional): only average alignment over + this many heads (default: all heads). 
+ + Returns: + tuple: + - the decoder's features of shape `(batch, tgt_len, embed_dim)` + - a dictionary with any model-specific outputs + """ + bs = prev_output_tokens.size(0) + if alignment_layer is None: + alignment_layer = self.num_layers - 1 + + enc: Optional[Tensor] = None + padding_mask: Optional[Tensor] = None + if encoder_out is not None and len(encoder_out["encoder_out"]) > 0: + enc = encoder_out["encoder_out"][0] + assert ( + enc.size()[1] == bs + ), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}" + if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0: + padding_mask = encoder_out["encoder_padding_mask"][0] + + # B x T x C -> T x B x C + x = prev_output_tokens.transpose(0, 1) + + self_attn_padding_mask: Optional[Tensor] = None + if self.cross_self_attention or tgt_mask is not None: + self_attn_padding_mask = tgt_mask + + ## relative position embedding + if self.args.relative_position_embedding: + x_len = x.shape[0] + pos_seq = torch.arange(0, x_len).long().to(x.device) + pos_seq = pos_seq[:, None] - pos_seq[None, :] + pos_k, pos_v = self.pos_emb(pos_seq) + else: + pos_k = None + + # decoder layers + attn_list = [] + attn: Optional[Tensor] = None + inner_states: List[Optional[Tensor]] = [x] + for idx, layer in enumerate(self.layers): + if incremental_state is None and not full_context_alignment: + self_attn_mask = self.buffered_future_mask(x) + else: + self_attn_mask = None + + x, layer_attn, _ = layer( + x, + enc, + padding_mask, + incremental_state, + self_attn_mask=self_attn_mask, + self_attn_padding_mask=self_attn_padding_mask, + need_attn=bool((idx == alignment_layer or alignment_layer == -1)), + need_head_weights=bool((idx == alignment_layer or alignment_layer == -1)), + pos_bias=pos_k, + ) + inner_states.append(x) + if layer_attn is not None and (idx == alignment_layer or alignment_layer == -1): + attn = layer_attn.float().to(x) + attn_list.append(attn.transpose(0, 1)) + + if attn is not None and len(attn_list) == 1: + if alignment_heads is not None: + attn = attn[:alignment_heads] + + # average probabilities over heads + attn = attn.mean(dim=0) + + if self.layer_norm is not None: + x = self.layer_norm(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x, {"attn": [attn if len(attn_list) <= 1 else attn_list], "inner_states": inner_states} + + # def max_positions(self): + # """Maximum output length supported by the decoder.""" + # return self.max_target_positions + + def buffered_future_mask(self, tensor): + dim = tensor.size(0) + # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround. 
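+        # Rebuild the cached causal mask only if it is empty, on the wrong device,
+        # or smaller than the current sequence length; otherwise reuse it and
+        # slice out the top-left `dim x dim` block below.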
+ if ( + self._future_mask.size(0) == 0 + or (not self._future_mask.device == tensor.device) + or self._future_mask.size(0) < dim + ): + self._future_mask = torch.triu( + utils.fill_with_neg_inf(torch.zeros([dim, dim], device=tensor.device)), 1, + ) + else: + self._future_mask = self._future_mask.to(tensor) + return self._future_mask[:dim, :dim] + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + for i in range(self.num_layers): + # update layer norms + layer_norm_map = { + "0": "self_attn_layer_norm", + "1": "encoder_attn_layer_norm", + "2": "final_layer_norm", + } + for old, new in layer_norm_map.items(): + for m in ("weight", "bias"): + k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m) + if k in state_dict: + state_dict[ + "{}.layers.{}.{}.{}".format(name, i, new, m) + ] = state_dict[k] + del state_dict[k] + + version_key = "{}.version".format(name) + if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2: + # earlier checkpoints did not normalize after the stack of layers + self.layer_norm = None + self.normalize = False + state_dict[version_key] = torch.Tensor([1]) + + return state_dict + + def set_num_updates(self, num_updates): + """State from trainer to pass along to model at every update.""" + + def _apply(m): + if hasattr(m, "set_num_updates") and m != self: + m.set_num_updates(num_updates) + + self.apply(_apply) diff --git a/artst/models/modules/encoder.py b/artst/models/modules/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..0ab0b901fbdccab1e3e864a6dd9482c62ea38ee3 --- /dev/null +++ b/artst/models/modules/encoder.py @@ -0,0 +1,380 @@ +# -------------------------------------------------------- +# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621) +# Github source: https://github.com/mbzuai-nlp/ArTST + +# Based on speecht5, fairseq and espnet code bases +# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet +# -------------------------------------------------------- + +from typing import Dict, List + +import numpy as np +import torch +import torch.nn as nn +import contextlib +from fairseq import utils +from fairseq.models import ( + FairseqEncoder, +) +from fairseq.modules import ( + FairseqDropout, + LayerNorm, + TransformerEncoderLayer, +) +from torch import Tensor +from .transformer_layer import TransformerSentenceEncoderLayer + + + +DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8) + +def Linear(in_features, out_features, bias=True): + m = nn.Linear(in_features, out_features, bias) + nn.init.xavier_uniform_(m.weight) + if bias: + nn.init.constant_(m.bias, 0.0) + return m + + +class RelativePositionalEncoding(torch.nn.Module): + def __init__(self, d_model, maxlen=1000, embed_v=False): + super(RelativePositionalEncoding, self).__init__() + + self.d_model = d_model + self.maxlen = maxlen + self.pe_k = torch.nn.Embedding(2*maxlen, d_model) + if embed_v: + self.pe_v = torch.nn.Embedding(2*maxlen, d_model) + self.embed_v = embed_v + + + def forward(self, pos_seq): + pos_seq[pos_seq < -self.maxlen] = -self.maxlen + pos_seq[pos_seq >= self.maxlen] = self.maxlen - 1 + pos_seq = pos_seq + self.maxlen + if self.embed_v: + return self.pe_k(pos_seq), self.pe_v(pos_seq) + else: + return self.pe_k(pos_seq), None + +class TransformerEncoder(FairseqEncoder): + """ + Transformer encoder consisting of *args.encoder_layers* layers. 
Each layer + is a :class:`TransformerEncoderLayer`. + + Args: + args (argparse.Namespace): parsed command-line arguments + dictionary (~fairseq.data.Dictionary): encoding dictionary + embed_tokens (torch.nn.Embedding): input embedding + """ + + def __init__(self, args, tgt_dict=None, embed_tokens=None): + self.args = args + super().__init__(None) + self.register_buffer("version", torch.Tensor([3])) + + self.dropout_module = FairseqDropout( + args.dropout, module_name=self.__class__.__name__ + ) + self.encoder_layerdrop = args.encoder_layerdrop + self.freeze_encoder_updates = args.freeze_encoder_updates + if args.no_freeze_encoder_layer is not None: + self.no_freeze_encoder_layer = eval(args.no_freeze_encoder_layer) + else: + self.no_freeze_encoder_layer = None + self.num_updates = 0 + export = getattr(args, "export", False) + + self.layers = nn.ModuleList([]) + self.layers.extend( + [self.build_encoder_layer(args) for i in range(args.encoder_layers)] + ) + self.num_layers = len(self.layers) + + self.use_sent_enc_layer = args.use_sent_enc_layer + self.unb_enc_layer = getattr(args, "unb_enc_layer", -1) + + self.layer_norm_first = args.layer_norm_first + self.layer_norm = LayerNorm(args.encoder_embed_dim, eps=args.layer_norm_eps, export=export) + + if args.share_ctc_embed and embed_tokens is not None: + self.proj = nn.Linear( + embed_tokens.weight.shape[1], + embed_tokens.weight.shape[0], + bias=False, + ) + self.proj.weight = embed_tokens.weight + elif tgt_dict is not None: + self.proj = Linear(args.encoder_embed_dim, len(tgt_dict)) + else: + self.proj = None + + if args.relative_position_embedding: + self.pos_emb = RelativePositionalEncoding(args.encoder_embed_dim//args.encoder_attention_heads, args.encoder_max_relative_position) + + + def build_encoder_layer(self, args): + if args.use_sent_enc_layer: + layer = TransformerSentenceEncoderLayer( + embedding_dim=args.encoder_embed_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=args.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + has_relative_attention_bias=args.relative_position_embedding, + ) + else: + layer = TransformerEncoderLayer(args) + return layer + + def forward( + self, + encoder_in, + encoder_padding_mask, + return_all_hiddens: bool = False, + tgt_layer=None, + ): + """ + Args: + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + src_lengths (torch.LongTensor): lengths of each source sentence of + shape `(batch)` + return_all_hiddens (bool, optional): also return all of the + intermediate hidden states (default: False). + token_embeddings (torch.Tensor, optional): precomputed embeddings + default `None` will recompute embeddings + + Returns: + dict: + - **encoder_out** (Tensor): the last encoder layer's output of + shape `(src_len, batch, embed_dim)` + - **encoder_padding_mask** (ByteTensor): the positions of + padding elements of shape `(batch, src_len)` + - **encoder_embedding** (Tensor): the (scaled) embedding lookup + of shape `(batch, src_len, embed_dim)` + - **encoder_states** (List[Tensor]): all intermediate + hidden states of shape `(src_len, batch, embed_dim)`. + Only populated if *return_all_hiddens* is True. 
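+                - **encoder_out_for_ctc** (Tensor): encoder states projected by
+                  the CTC head, of shape `(src_len, batch, vocab)`; `[None]` when
+                  no CTC projection is attached
+                - **decoder_input** (Tensor): hidden states taken at the layer
+                  given by *unb_enc_layer*, later substituted for the encoder
+                  output fed to the decoder; `[None]` unless that layer is set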
+ """ + if self.no_freeze_encoder_layer is None: + ft = self.freeze_encoder_updates <= self.num_updates + else: + ft = True + with torch.no_grad() if not ft else contextlib.ExitStack(): + encoder_out = self.forward_scriptable( + encoder_in, encoder_padding_mask, return_all_hiddens, tgt_layer=tgt_layer, + ) + + # CTC and bert + if self.proj: + x_for_ctc = self.proj(self.dropout_module(encoder_out["encoder_out"][0])) + else: + x_for_ctc = None + + encoder_out["encoder_out_for_ctc"] = [x_for_ctc] # T x B x C + + return encoder_out + + # TorchScript doesn't support super() method so that the scriptable Subclass + # can't access the base class model in Torchscript. + # Current workaround is to add a helper function with different name and + # call the helper function from scriptable Subclass. + def forward_scriptable( + self, + encoder_in, + encoder_padding_mask, + return_all_hiddens: bool = False, + tgt_layer=None, + ): + """ + Args: + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + src_lengths (torch.LongTensor): lengths of each source sentence of + shape `(batch)` + return_all_hiddens (bool, optional): also return all of the + intermediate hidden states (default: False). + token_embeddings (torch.Tensor, optional): precomputed embeddings + default `None` will recompute embeddings + + Returns: + dict: + - **encoder_out** (Tensor): the last encoder layer's output of + shape `(src_len, batch, embed_dim)` + - **encoder_padding_mask** (ByteTensor): the positions of + padding elements of shape `(batch, src_len)` + - **encoder_embedding** (Tensor): the (scaled) embedding lookup + of shape `(batch, src_len, embed_dim)` + - **encoder_states** (List[Tensor]): all intermediate + hidden states of shape `(src_len, batch, embed_dim)`. + Only populated if *return_all_hiddens* is True. 
+ """ + if self.no_freeze_encoder_layer is not None: + ft = self.freeze_encoder_updates <= self.num_updates + else: + ft = True + with torch.no_grad() if not ft else contextlib.ExitStack(): + # compute padding mask + if not self.use_sent_enc_layer: + has_pads = encoder_in.device.type == "xla" or encoder_padding_mask.any() + + if not self.layer_norm_first: + encoder_in = self.layer_norm(encoder_in) + + encoder_in = self.dropout_module(encoder_in) + + # B x T x C -> T x B x C + x = encoder_in.transpose(0, 1) + + encoder_states = [] + + if return_all_hiddens: + encoder_states.append(x) + + ## relative position embedding + if self.args.relative_position_embedding: + x_len = x.shape[0] + pos_seq = torch.arange(0, x_len).long().to(x.device) + pos_seq = pos_seq[:, None] - pos_seq[None, :] + pos_k, pos_v = self.pos_emb(pos_seq) + else: + pos_k = None + + # encoder layers + r = None + d = None + for i, layer in enumerate(self.layers): + dropout_probability = np.random.random() + + with torch.no_grad() if (not ft) and i not in self.no_freeze_encoder_layer else contextlib.ExitStack(): + if not self.training or (dropout_probability > self.encoder_layerdrop) or i == self.unb_enc_layer: + if self.use_sent_enc_layer: + x, _ = layer(x, self_attn_padding_mask=encoder_padding_mask, self_attn_mask=None, need_weights=False, pos_bias=pos_k) + # x, _ = layer(x, self_attn_padding_mask=encoder_padding_mask, need_weights=False, pos_bias=pos_k) + else: + x = layer(x, encoder_padding_mask=encoder_padding_mask if has_pads else None, attn_mask=None) + # x = layer(x, encoder_padding_mask=encoder_padding_mask if has_pads else None) + if i == self.unb_enc_layer: + d = x + + if i == tgt_layer: + r = x + break + + if return_all_hiddens: + assert encoder_states is not None + encoder_states.append(x) + + with torch.no_grad() if not ft else contextlib.ExitStack(): + # Finally T x B x C + if self.layer_norm_first: + x = self.layer_norm(x.transpose(0, 1)).transpose(0, 1) + + if r is not None: + x = r + + # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in + # `forward` so we use a dictionary instead. + # TorchScript does not support mixed values so the values are all lists. + # The empty list is equivalent to None. + return { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [encoder_padding_mask], # B x T + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": [], + "decoder_input": [d], + } + + @torch.jit.export + def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order): + """ + Reorder encoder output according to *new_order*. 
+ + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + *encoder_out* rearranged according to *new_order* + """ + if len(encoder_out["encoder_out"]) == 0: + new_encoder_out = [] + else: + new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)] + + if len(encoder_out["encoder_out_for_ctc"]) == 0: + new_x_for_ctc = [] + else: + new_x_for_ctc = [encoder_out["encoder_out_for_ctc"][0].index_select(1, new_order)] + + if len(encoder_out["encoder_padding_mask"]) == 0: + new_encoder_padding_mask = [] + else: + new_encoder_padding_mask = [ + encoder_out["encoder_padding_mask"][0].index_select(0, new_order) + ] + + if len(encoder_out["src_tokens"]) == 0: + src_tokens = [] + else: + src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)] + + if len(encoder_out["decoder_input"]) == 0 or encoder_out["decoder_input"][0] is None: + new_decoder_input = [] + else: + new_decoder_input = [ + encoder_out["decoder_input"][0].index_select(0, new_order) + ] + + encoder_states = encoder_out["encoder_states"] + if len(encoder_states) > 0: + for idx, state in enumerate(encoder_states): + encoder_states[idx] = state.index_select(1, new_order) + + return { + "encoder_out": new_encoder_out, # T x B x C + "encoder_padding_mask": new_encoder_padding_mask, # B x T + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": src_tokens, # B x T + "encoder_out_for_ctc": new_x_for_ctc, # T x B x C + "decoder_input": new_decoder_input, + } + + # def max_positions(self): + # """Maximum input length supported by the encoder.""" + # return self.max_source_positions + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + # if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): + # weights_key = "{}.embed_positions.weights".format(name) + # if weights_key in state_dict: + # print("deleting {0}".format(weights_key)) + # del state_dict[weights_key] + # state_dict[ + # "{}.embed_positions._float_tensor".format(name) + # ] = torch.FloatTensor(1) + for i in range(self.num_layers): + # update layer norms + if not isinstance(self.layers[i], TransformerSentenceEncoderLayer): + self.layers[i].upgrade_state_dict_named( + state_dict, "{}.layers.{}".format(name, i) + ) + + version_key = "{}.version".format(name) + if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: + # earlier checkpoints did not normalize after the stack of layers + self.layer_norm = None + self.normalize = False + state_dict[version_key] = torch.Tensor([1]) + return state_dict + + def set_num_updates(self, num_updates): + """Set the number of parameters updates.""" + super().set_num_updates(num_updates) + self.num_updates = num_updates + \ No newline at end of file diff --git a/artst/models/modules/multihead_attention.py b/artst/models/modules/multihead_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..d4cb0f319bcf9f34a2b22b11f284f30cb8958096 --- /dev/null +++ b/artst/models/modules/multihead_attention.py @@ -0,0 +1,525 @@ +# -------------------------------------------------------- +# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621) +# Github source: https://github.com/mbzuai-nlp/ArTST + +# Based on speecht5, fairseq and espnet code bases +# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet +# 
-------------------------------------------------------- + +import math +from typing import Dict, Optional, Tuple + +import torch +import torch.nn.functional as F +from fairseq import utils +from fairseq.incremental_decoding_utils import with_incremental_state +from fairseq.modules.fairseq_dropout import FairseqDropout +from fairseq.modules.quant_noise import quant_noise +from torch import Tensor, nn +from torch.nn import Parameter + + +@with_incremental_state +class MultiheadAttention(nn.Module): + """Multi-headed attention. + + See "Attention Is All You Need" for more details. + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + has_relative_attention_bias=False, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__ + ) + + self.has_relative_attention_bias = has_relative_attention_bias + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ( + "Self-attention requires query, key and " "value to be of the same size" + ) + + self.k_proj = quant_noise( + nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.v_proj = quant_noise( + nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.q_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) + + self.out_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.reset_parameters() + + self.onnx_trace = False + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + + def forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + 
before_softmax: bool = False, + need_head_weights: bool = False, + position_bias: Optional[Tensor] = None + ) -> Tuple[Tensor, Optional[Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. + """ + if need_head_weights: + need_weights = True + + is_tpu = query.device.type == "xla" + + tgt_len, bsz, embed_dim = query.size() + src_len = tgt_len + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + if key is not None: + src_len, key_bsz, _ = key.size() + if not torch.jit.is_scripting(): + assert key_bsz == bsz + assert value is not None + assert src_len, bsz == value.shape[:2] + + if ( + not self.onnx_trace + and not is_tpu # don't use PyTorch version on TPUs + and incremental_state is None + and not static_kv + # A workaround for quantization to work. Otherwise JIT compilation + # treats bias in linear module as method. + and not torch.jit.is_scripting() + and not self.has_relative_attention_bias + ): + assert key is not None and value is not None + # Hawau: + if query.dtype != attn_mask.dtype: + attn_mask = attn_mask.type(query.dtype) + # My code ends here + return F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + torch.empty([0]), + torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout_module.p, + self.out_proj.weight, + self.out_proj.bias, + self.training or self.dropout_module.apply_during_inference, + key_padding_mask, + need_weights, + attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + ) + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if 
key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros(key_padding_mask.size(0), 1), + ], + dim=1, + ) + + q = ( + q.contiguous() + .view(tgt_len, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + if k is not None: + k = ( + k.contiguous() + .view(-1, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + if v is not None: + v = ( + v.contiguous() + .view(-1, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + src_len = k.size(1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + + saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + assert k.size(1) == src_len + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. 
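# --- Editorial sketch (not part of multihead_attention.py) -----------------
# The saved_state handling above is fairseq's key/value cache for incremental
# decoding: at each generation step only the newest position is projected, and
# keys/values from earlier steps are read back from the incremental state and
# concatenated before attention is computed. A minimal, self-contained
# illustration of the same idea in plain PyTorch (toy shapes; the helper name
# `cached_attention_step` is hypothetical):
import torch

def cached_attention_step(q_t, k_t, v_t, cache):
    # q_t, k_t, v_t: (batch, 1, dim) projections for the newest time step.
    # cache: dict holding previously seen keys/values, empty on step 0.
    if "k" in cache:
        k = torch.cat([cache["k"], k_t], dim=1)  # (batch, t + 1, dim)
        v = torch.cat([cache["v"], v_t], dim=1)
    else:
        k, v = k_t, v_t
    cache["k"], cache["v"] = k, v                # persist for the next step
    scores = q_t @ k.transpose(1, 2) / k.size(-1) ** 0.5
    return torch.softmax(scores, dim=-1) @ v     # (batch, 1, dim)

# Toy usage: three decoding steps reuse the cache instead of re-projecting the
# whole prefix, which is what the prev_key / prev_value buffers provide in the
# real module.
cache = {}
for _ in range(3):
    step = torch.randn(2, 1, 8)
    out = cached_attention_step(step, step, step, cache)
# ---------------------------------------------------------------------------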
+ if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + torch.zeros(key_padding_mask.size(0), 1).type_as( + key_padding_mask + ), + ], + dim=1, + ) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + if position_bias is not None and self.has_relative_attention_bias: ## first order + ## position_bias: [241, 241, 64] + #print ("attn_weights: ", attn_weights.size()) # [492, 241, 241] + reshape_q = q.contiguous().view(bsz * self.num_heads, -1, self.head_dim).transpose(0,1) #[241, 492, 64] + #print ("reshape_q: ", reshape_q.size()) + B = torch.matmul(reshape_q, position_bias.transpose(-2, -1)) + #print ("B: ", B.size()) ## [241, 492, 241] + #B = B.transpose(0, 1).view(bsz, self.num_heads, position_bias.size(0), position_bias.size(1)) + B = B.transpose(0, 1).view(bsz*self.num_heads, position_bias.size(0), position_bias.size(1)) + #print ("B 2: ", B.size()) + attn_weights += B + else: + position_bias = None + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + if self.onnx_trace: + attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + if not is_tpu: + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v + + attn_weights_float = utils.softmax( + attn_weights, dim=-1, onnx_trace=self.onnx_trace + ) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + if self.onnx_trace and attn.size(1) == 1: + # when ONNX tracing a single decoder step (sequence length == 1) + # the transpose is a no-op copy before view, thus unnecessary + attn = attn.contiguous().view(tgt_len, bsz, embed_dim) + else: + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[Tensor] = None + if need_weights: + attn_weights = attn_weights_float.view( + bsz, self.num_heads, tgt_len, src_len + ).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: 
int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 + ) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + if src_len > prev_key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), filler.float()], dim=1 + ) + else: + new_key_padding_mask = prev_key_padding_mask.float() + elif key_padding_mask is not None: + if src_len > key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [filler.float(), key_padding_mask.float()], dim=1 + ) + else: + new_key_padding_mask = key_padding_mask.float() + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + @torch.jit.export + def reorder_incremental_state( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + new_order: Tensor, + ): + """Reorder buffered internal state (for incremental generation).""" + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is not None: + for k in input_buffer.keys(): + input_buffer_k = input_buffer[k] + if input_buffer_k is not None: + if self.encoder_decoder_attention and input_buffer_k.size( + 0 + ) == new_order.size(0): + break + input_buffer[k] = input_buffer_k.index_select(0, new_order) + incremental_state = self._set_input_buffer(incremental_state, input_buffer) + return incremental_state + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ) -> Dict[str, Optional[Tensor]]: + result = self.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + else: + empty_result: Dict[str, Optional[Tensor]] = {} + return empty_result + + def _set_input_buffer( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + buffer: Dict[str, Optional[Tensor]], + ): + return self.set_incremental_state(incremental_state, "attn_state", buffer) + + def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): + return attn_weights + + def upgrade_state_dict_named(self, state_dict, name): + prefix = name + "." 
if name != "" else "" + items_to_add = {} + keys_to_remove = [] + for k in state_dict.keys(): + if k.endswith(prefix + "in_proj_weight"): + # in_proj_weight used to be q + k + v with same dimensions + dim = int(state_dict[k].shape[0] / 3) + items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim] + items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim] + items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :] + + keys_to_remove.append(k) + + k_bias = prefix + "in_proj_bias" + if k_bias in state_dict.keys(): + dim = int(state_dict[k].shape[0] / 3) + items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim] + items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][ + dim : 2 * dim + ] + items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :] + + keys_to_remove.append(prefix + "in_proj_bias") + + for k in keys_to_remove: + del state_dict[k] + + for key, value in items_to_add.items(): + state_dict[key] = value diff --git a/artst/models/modules/speaker_decoder_postnet.py b/artst/models/modules/speaker_decoder_postnet.py new file mode 100644 index 0000000000000000000000000000000000000000..6d82a43e7f2592901e8aca915b6418174b5d2abe --- /dev/null +++ b/artst/models/modules/speaker_decoder_postnet.py @@ -0,0 +1,196 @@ +# -------------------------------------------------------- +# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621) +# Github source: https://github.com/mbzuai-nlp/ArTST + +# Based on speecht5, fairseq and espnet code bases +# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet +# -------------------------------------------------------- + +import torch.nn as nn +import math +import torch +import torch.nn.functional as F + + +class AngularMargin(nn.Module): + """ + An implementation of Angular Margin (AM) proposed in the following + paper: '''Margin Matters: Towards More Discriminative Deep Neural Network + Embeddings for Speaker Recognition''' (https://arxiv.org/abs/1906.07317) + + Arguments + --------- + margin : float + The margin for cosine similiarity + scale : float + The scale for cosine similiarity + + Return + --------- + predictions : torch.Tensor + + Example + ------- + >>> pred = AngularMargin() + >>> outputs = torch.tensor([ [1., -1.], [-1., 1.], [0.9, 0.1], [0.1, 0.9] ]) + >>> targets = torch.tensor([ [1., 0.], [0., 1.], [ 1., 0.], [0., 1.] ]) + >>> predictions = pred(outputs, targets) + >>> predictions[:,0] > predictions[:,1] + tensor([ True, False, True, False]) + """ + + def __init__(self, margin=0.0, scale=1.0): + super(AngularMargin, self).__init__() + self.margin = margin + self.scale = scale + + def forward(self, outputs, targets): + """Compute AM between two tensors + + Arguments + --------- + outputs : torch.Tensor + The outputs of shape [N, C], cosine similarity is required. + targets : torch.Tensor + The targets of shape [N, C], where the margin is applied for. + + Return + --------- + predictions : torch.Tensor + """ + outputs = outputs - self.margin * targets + return self.scale * outputs + + +class AdditiveAngularMargin(AngularMargin): + """ + An implementation of Additive Angular Margin (AAM) proposed + in the following paper: '''Margin Matters: Towards More Discriminative Deep + Neural Network Embeddings for Speaker Recognition''' + (https://arxiv.org/abs/1906.07317) + + Arguments + --------- + margin : float + The margin for cosine similiarity, usually 0.2. 
+ scale: float + The scale for cosine similiarity, usually 30. + + Returns + ------- + predictions : torch.Tensor + Tensor. + Example + ------- + >>> outputs = torch.tensor([ [1., -1.], [-1., 1.], [0.9, 0.1], [0.1, 0.9] ]) + >>> targets = torch.tensor([ [1., 0.], [0., 1.], [ 1., 0.], [0., 1.] ]) + >>> pred = AdditiveAngularMargin() + >>> predictions = pred(outputs, targets) + >>> predictions[:,0] > predictions[:,1] + tensor([ True, False, True, False]) + """ + + def __init__(self, margin=0.0, scale=1.0, easy_margin=False): + super(AdditiveAngularMargin, self).__init__(margin, scale) + self.easy_margin = easy_margin + + self.cos_m = math.cos(self.margin) + self.sin_m = math.sin(self.margin) + self.th = math.cos(math.pi - self.margin) + self.mm = math.sin(math.pi - self.margin) * self.margin + + def forward(self, outputs, targets): + """ + Compute AAM between two tensors + + Arguments + --------- + outputs : torch.Tensor + The outputs of shape [N, C], cosine similarity is required. + targets : torch.Tensor + The targets of shape [N, C], where the margin is applied for. + + Return + --------- + predictions : torch.Tensor + """ + cosine = outputs.float() + sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1)) + phi = cosine * self.cos_m - sine * self.sin_m # cos(theta + m) + if self.easy_margin: + phi = torch.where(cosine > 0, phi, cosine) + else: + phi = torch.where(cosine > self.th, phi, cosine - self.mm) + outputs = (targets * phi) + ((1.0 - targets) * cosine) + return self.scale * outputs + + +class SpeakerDecoderPostnet(nn.Module): + """Speaker Identification Postnet. + + Arguments + --------- + embed_dim : int + The size of embedding. + class_num: int + The number of classes. + args : Namespace + + Return + --------- + embed : torch.Tensor + output : torch.Tensor + """ + + def __init__(self, embed_dim, class_num, args): + super(SpeakerDecoderPostnet, self).__init__() + self.embed_dim = embed_dim + self.class_num = class_num + self.no_pooling_bn = getattr(args, "sid_no_pooling_bn", False) + self.no_embed_postnet = getattr(args, "sid_no_embed_postnet", False) + self.normalize_postnet = getattr(args, "sid_normalize_postnet", False) + self.softmax_head = getattr(args, "sid_softmax_type", "softmax") + if not self.no_pooling_bn: + self.bn_pooling = nn.BatchNorm1d(args.decoder_output_dim) + else: + self.bn_pooling = None + if not self.no_embed_postnet: + self.output_embedding = nn.Linear(args.decoder_output_dim, embed_dim, bias=False) + self.bn_embedding = nn.BatchNorm1d(embed_dim) + else: + self.output_embedding = None + self.bn_embedding = None + self.embed_dim = args.decoder_output_dim + self.output_projection = nn.Linear(self.embed_dim, class_num, bias=False) + if self.softmax_head == "amsoftmax": + self.output_layer = AngularMargin(args.softmax_margin, args.softmax_scale) + elif self.softmax_head == "aamsoftmax": + self.output_layer = AdditiveAngularMargin(args.softmax_margin, args.softmax_scale, args.softmax_easy_margin) + else: + self.output_layer = None + if self.output_embedding is not None: + nn.init.normal_(self.output_embedding.weight, mean=0, std=embed_dim ** -0.5) + nn.init.normal_(self.output_projection.weight, mean=0, std=class_num ** -0.5) + + def forward(self, x, target=None): + """ + Parameters + ---------- + x : torch.Tensor of shape [batch, channel] or [batch, time, channel] + target : torch.Tensor of shape [batch, channel] + """ + if self.bn_pooling is not None: + x = self.bn_pooling(x) + if self.output_embedding is not None and self.bn_embedding is not None: + embed 
= self.bn_embedding(self.output_embedding(x)) + else: + embed = x + if self.output_layer is not None or self.normalize_postnet: + x_norm = F.normalize(embed, p=2, dim=1) + w_norm = F.normalize(self.output_projection.weight, p=2, dim=1) # [out_dim, in_dim] + output = F.linear(x_norm, w_norm) + if self.training and target is not None and self.output_layer is not None: + output = self.output_layer(output, target) + else: + output = self.output_projection(embed) + return output, embed diff --git a/artst/models/modules/speech_decoder_postnet.py b/artst/models/modules/speech_decoder_postnet.py new file mode 100644 index 0000000000000000000000000000000000000000..87e9eeb679b09d2fa6dedab58ad5811e6d1049a5 --- /dev/null +++ b/artst/models/modules/speech_decoder_postnet.py @@ -0,0 +1,75 @@ +# -------------------------------------------------------- +# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621) +# Github source: https://github.com/mbzuai-nlp/ArTST + +# Based on speecht5, fairseq and espnet code bases +# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet +# -------------------------------------------------------- + +import contextlib +import torch +import torch.nn as nn + +from espnet.nets.pytorch_backend.tacotron2.decoder import Postnet + + +class SpeechDecoderPostnet(nn.Module): + """ + + Args: + in_channels (int): the number of input channels + mid_channels (int): the number of intermediate channels + out_channels (int): the number of output channels + kernel_sizes (List[int]): the kernel size for each convolutional layer + """ + + def __init__( + self, + odim, + args, + ): + super(SpeechDecoderPostnet, self).__init__() + # define decoder postnet + # define final projection + self.feat_out = torch.nn.Linear(args.decoder_embed_dim, odim * args.reduction_factor) + self.prob_out = torch.nn.Linear(args.decoder_embed_dim, args.reduction_factor) + + # define postnet + self.postnet = ( + None + if args.postnet_layers == 0 + else Postnet( + idim=0, + odim=odim, + n_layers=args.postnet_layers, + n_chans=args.postnet_chans, + n_filts=args.postnet_filts, + use_batch_norm=args.use_batch_norm, + dropout_rate=args.postnet_dropout_rate, + ) + ) + + self.odim = odim + self.num_updates = 0 + self.freeze_decoder_updates = args.freeze_decoder_updates + + def forward(self, zs): + ft = self.freeze_decoder_updates <= self.num_updates + with torch.no_grad() if not ft else contextlib.ExitStack(): + # (B, Lmax//r, odim * r) -> (B, Lmax//r * r, odim) + before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim) + # (B, Lmax//r, r) -> (B, Lmax//r * r) + logits = self.prob_out(zs).view(zs.size(0), -1) + # postnet -> (B, Lmax//r * r, odim) + if self.postnet is None: + after_outs = before_outs + else: + after_outs = before_outs + self.postnet( + before_outs.transpose(1, 2) + ).transpose(1, 2) + + return before_outs, after_outs, logits + + def set_num_updates(self, num_updates): + """Set the number of parameters updates.""" + self.num_updates = num_updates diff --git a/artst/models/modules/speech_decoder_prenet.py b/artst/models/modules/speech_decoder_prenet.py new file mode 100644 index 0000000000000000000000000000000000000000..061b7af3ebd667781ccaf20d6300d7d5db7f4555 --- /dev/null +++ b/artst/models/modules/speech_decoder_prenet.py @@ -0,0 +1,109 @@ +# -------------------------------------------------------- +# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621) +# Github source: 
https://github.com/mbzuai-nlp/ArTST + +# Based on speecht5, fairseq and espnet code bases +# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet +# -------------------------------------------------------- + +import contextlib +import torch +import torch.nn as nn + +import torch.nn.functional as F +from espnet.nets.pytorch_backend.tacotron2.decoder import Prenet as TacotronDecoderPrenet +from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding +from espnet.nets.pytorch_backend.transformer.embedding import ScaledPositionalEncoding +from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask + + +class SpeechDecoderPrenet(nn.Module): + """ + + Args: + in_channels (int): the number of input channels + mid_channels (int): the number of intermediate channels + out_channels (int): the number of output channels + kernel_sizes (List[int]): the kernel size for each convolutional layer + """ + + def __init__( + self, + odim, + args, + ): + super(SpeechDecoderPrenet, self).__init__() + # define decoder prenet + if args.dprenet_layers != 0: + # decoder prenet + decoder_input_layer = torch.nn.Sequential( + TacotronDecoderPrenet( + idim=odim, + n_layers=args.dprenet_layers, + n_units=args.dprenet_units, + dropout_rate=args.dprenet_dropout_rate, + ), + torch.nn.Linear(args.dprenet_units, args.decoder_embed_dim), + ) + else: + decoder_input_layer = "linear" + + pos_enc_class = ( + ScaledPositionalEncoding if args.dec_use_scaled_pos_enc else PositionalEncoding + ) + + if decoder_input_layer == "linear": + self.decoder_prenet = torch.nn.Sequential( + torch.nn.Linear(odim, args.decoder_embed_dim), + torch.nn.LayerNorm(args.decoder_embed_dim), + torch.nn.Dropout(args.transformer_dec_dropout_rate), + torch.nn.ReLU(), + pos_enc_class(args.decoder_embed_dim, args.transformer_dec_positional_dropout_rate), + ) + elif isinstance(decoder_input_layer, torch.nn.Module): + self.decoder_prenet = torch.nn.Sequential( + decoder_input_layer, pos_enc_class(args.decoder_embed_dim, args.transformer_dec_positional_dropout_rate, max_len=args.max_speech_positions) + ) + + if args.spk_embed_integration_type == 'pre': + self.spkembs_layer = torch.nn.Sequential( + torch.nn.Linear(args.spk_embed_dim + args.decoder_embed_dim, args.decoder_embed_dim), torch.nn.ReLU() + ) + self.num_updates = 0 + self.freeze_decoder_updates = args.freeze_decoder_updates + + def forward(self, prev_output_tokens, tgt_lengths_in=None, spkembs=None): + ft = self.freeze_decoder_updates <= self.num_updates + with torch.no_grad() if not ft else contextlib.ExitStack(): + prev_output_tokens = self.decoder_prenet(prev_output_tokens) + + if spkembs is not None: + spkembs = F.normalize(spkembs).unsqueeze(1).expand(-1, prev_output_tokens.size(1), -1) + prev_output_tokens = self.spkembs_layer(torch.cat([prev_output_tokens, spkembs], dim=-1)) + + if tgt_lengths_in is not None: + tgt_frames_mask = ~(self._source_mask(tgt_lengths_in).squeeze(1)) + else: + tgt_frames_mask = None + return prev_output_tokens, tgt_frames_mask + + def _source_mask(self, ilens): + """Make masks for self-attention. + Args: + ilens (LongTensor or List): Batch of lengths (B,). + Returns: + Tensor: Mask tensor for self-attention. 
+ dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + Examples: + >>> ilens = [5, 3] + >>> self._source_mask(ilens) + tensor([[[1, 1, 1, 1, 1], + [[1, 1, 1, 0, 0]]], dtype=torch.uint8) + """ + x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device) + return x_masks.unsqueeze(-2) + + def set_num_updates(self, num_updates): + """Set the number of parameters updates.""" + self.num_updates = num_updates diff --git a/artst/models/modules/speech_encoder_postnet.py b/artst/models/modules/speech_encoder_postnet.py new file mode 100644 index 0000000000000000000000000000000000000000..8a368aa2ed82b53afbabb7418c50c929cdb34768 --- /dev/null +++ b/artst/models/modules/speech_encoder_postnet.py @@ -0,0 +1,123 @@ +# -------------------------------------------------------- +# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621) +# Github source: https://github.com/mbzuai-nlp/ArTST + +# Based on speecht5, fairseq and espnet code bases +# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet +# -------------------------------------------------------- + +import logging +import torch.nn as nn +import torch + + +logger = logging.getLogger(__name__) + +class SpeechEncoderPostnet(nn.Module): + """ + + Args: + in_channels (int): the number of input channels + mid_channels (int): the number of intermediate channels + out_channels (int): the number of output channels + kernel_sizes (List[int]): the kernel size for each convolutional layer + """ + + def __init__(self, dictionaries, args): + super(SpeechEncoderPostnet, self).__init__() + # modules below are not needed during fine-tuning + self.target_glu = args.target_glu + self.skip_masked = args.skip_masked + self.skip_nomask = args.skip_nomask + self.logit_temp = args.logit_temp + + final_dim = ( + args.final_dim if args.final_dim > 0 else args.encoder_embed_dim + ) + if any([d is None for d in dictionaries]): + logger.info( + "cannot find dictionary. 
assume will be used for fine-tuning" + ) + else: + self.num_classes = [len(d) for d in dictionaries] + self.label_embs_concat = nn.Parameter( + torch.FloatTensor(sum(self.num_classes), final_dim) + ) + nn.init.uniform_(self.label_embs_concat) + self.untie_final_proj = args.untie_final_proj + if self.untie_final_proj: + self.final_proj = nn.Linear( + args.encoder_embed_dim, final_dim * len(dictionaries) + ) + else: + self.final_proj = nn.Linear(args.encoder_embed_dim, final_dim) + + def compute_nce(self, x, pos, negs): + neg_is_pos = (pos == negs).all(-1) + pos = pos.unsqueeze(0) + targets = torch.cat([pos, negs], dim=0) + + logits = torch.cosine_similarity( + x.float(), targets.float(), dim=-1 + ).type_as(x) + logits /= self.logit_temp + if neg_is_pos.any(): + logits[1:][neg_is_pos] = float("-inf") + logits = logits.transpose(0, 1) # (num_x, num_cls+1) + return logits + + def forward(self, x, padding_mask, mask_indices, target_list): + def compute_pred(proj_x, target, label_embs): + # compute logits for the i-th label set + y = torch.index_select(label_embs, 0, target.long()) + negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1) + if self.target_glu: + y = self.target_glu(y) + negs = self.target_glu(negs) + # proj_x: (S, D) + # y: (S, D) + # negs: (Neg, S, D) + return self.compute_nce(proj_x, y, negs) + + label_embs_list = self.label_embs_concat.split(self.num_classes, 0) + + if not self.skip_masked: + masked_indices = torch.logical_and(~padding_mask, mask_indices) + proj_x_m = self.final_proj(x[masked_indices]) + if self.untie_final_proj: + proj_x_m_list = proj_x_m.chunk(len(target_list), dim=-1) + else: + proj_x_m_list = [proj_x_m for _ in range(len(target_list))] + logit_m_list = [ + compute_pred(proj_x_m, t[masked_indices], label_embs_list[i]) + for i, (proj_x_m, t) in enumerate( + zip(proj_x_m_list, target_list) + ) + ] + else: + logit_m_list = [None for _ in target_list] + + if not self.skip_nomask: + nomask_indices = torch.logical_and(~padding_mask, ~mask_indices) + proj_x_u = self.final_proj(x[nomask_indices]) + if self.untie_final_proj: + proj_x_u_list = proj_x_u.chunk(len(target_list), dim=-1) + else: + proj_x_u_list = [proj_x_u for _ in range(len(target_list))] + + logit_u_list = [ + compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i]) + for i, (proj_x_u, t) in enumerate( + zip(proj_x_u_list, target_list) + ) + ] + else: + logit_u_list = [None for _ in target_list] + + result = { + "logit_m_list": logit_m_list, + "logit_u_list": logit_u_list, + "padding_mask": padding_mask, + } + + return result diff --git a/artst/models/modules/speech_encoder_prenet.py b/artst/models/modules/speech_encoder_prenet.py new file mode 100644 index 0000000000000000000000000000000000000000..e0c724ad5a1ccf932996dedf2e532fdb188d23df --- /dev/null +++ b/artst/models/modules/speech_encoder_prenet.py @@ -0,0 +1,373 @@ +# -------------------------------------------------------- +# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621) +# Github source: https://github.com/mbzuai-nlp/ArTST + +# Based on speecht5, fairseq and espnet code bases +# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet +# -------------------------------------------------------- + +import logging +import math +import torch +import contextlib +from typing import List, Tuple +import torch.nn as nn + +from fairseq.data.data_utils import lengths_to_padding_mask +from fairseq.data.data_utils import compute_mask_indices 
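# --- Editorial sketch (not part of the repository files) --------------------
# SpeechEncoderPostnet.compute_nce above scores each frame against one positive
# label embedding and a stack of negatives with temperature-scaled cosine
# similarity (the HuBERT-style InfoNCE objective). A minimal numeric sketch of
# that scoring step, with toy shapes and a hypothetical helper name; the real
# module additionally masks out negatives that are identical to the positive.
import torch

def nce_logits(x, pos, negs, logit_temp=0.1):
    # x: (S, D) projected frames; pos: (S, D); negs: (N, S, D)
    targets = torch.cat([pos.unsqueeze(0), negs], dim=0)     # (N + 1, S, D)
    logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1)
    return (logits / logit_temp).transpose(0, 1)             # (S, N + 1); index 0 is the positive

frames, positives = torch.randn(4, 16), torch.randn(4, 16)
negatives = torch.randn(5, 4, 16)
print(nce_logits(frames, positives, negatives).shape)        # torch.Size([4, 6])
# -----------------------------------------------------------------------------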
+from fairseq.modules import ( + PositionalEmbedding, + Fp32GroupNorm, + FairseqDropout, + SamePad, + GradMultiply, + LayerNorm, + Fp32LayerNorm, + TransposeLast, +) +import numpy as np + +logger = logging.getLogger(__name__) + + +class LinearLayer(nn.Module): + def __init__(self, idim, odom, dropout=0): + super(LinearLayer, self).__init__() + self.linear = nn.Sequential( + nn.Linear(idim, odom), + nn.LayerNorm(odom), + nn.Dropout(dropout), + nn.ReLU(), + ) + + def get_out_seq_lens_tensor(self, in_seq_lens_tensor): + out = in_seq_lens_tensor.clone() + return out + + def forward(self, src_tokens, src_lengths): + """ + src_tokens: [B, T, C] + src_lengths: [B] + """ + x = self.linear(src_tokens) + x = x.transpose(0, 1).contiguous() # -> T x B x C + return x, src_lengths + + +class SpeechEncoderPrenet(nn.Module): + """ + + Args: + in_channels (int): the number of input channels + mid_channels (int): the number of intermediate channels + out_channels (int): the number of output channels + kernel_sizes (List[int]): the kernel size for each convolutional layer + """ + + def __init__(self, args): + super(SpeechEncoderPrenet, self).__init__() + self.dropout_module = FairseqDropout( + p=args.dropout, module_name=self.__class__.__name__ + ) + self.embed_scale = math.sqrt(args.encoder_embed_dim) + if args.no_scale_embedding: + self.embed_scale = 1.0 + self.padding_idx = 1 + self.freeze_encoder_updates = args.freeze_encoder_updates + self.num_updates = 0 + assert args.encoder_speech_prenet in ["conv", "linear"], args.encoder_speech_prenet + feature_enc_layers = eval(args.conv_feature_layers) # noqa + self.embed = feature_enc_layers[-1][0] + + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + mode=args.extractor_mode, + conv_bias=args.conv_bias, + ) + feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers]) + self.feat2tar_ratio = ( + args.label_rates * feature_ds_rate / args.sample_rate + ) + + self.post_extract_proj = ( + nn.Linear(self.embed, args.encoder_embed_dim) + if self.embed != args.encoder_embed_dim + else None + ) + + self.use_conv_pos = args.use_conv_pos + self.use_sinc_pos = args.use_sinc_pos + self.use_abs_pos = getattr(args, "use_abs_pos", False) + + self.feature_grad_mult = args.feature_grad_mult + if self.use_conv_pos: + self.layer_norm = LayerNorm(self.embed) + self.pos_conv = nn.Conv1d( + args.encoder_embed_dim, + args.encoder_embed_dim, + kernel_size=args.conv_pos, + padding=args.conv_pos // 2, + groups=args.conv_pos_groups, + ) + dropout = 0 + std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * args.encoder_embed_dim)) + nn.init.normal_(self.pos_conv.weight, mean=0, std=std) + nn.init.constant_(self.pos_conv.bias, 0) + self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) + self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU()) + + assert not (self.use_sinc_pos and self.use_abs_pos), f"sinc pos: {self.use_sinc_pos} abs pos: {self.use_abs_pos}" + if self.use_sinc_pos: + self.embed_positions = PositionalEmbedding( + args.max_speech_positions, args.encoder_embed_dim, self.padding_idx + ) + if self.use_abs_pos: + self.embed_positions = PositionalEmbedding( + args.max_speech_positions, args.encoder_embed_dim, self.padding_idx, learned=True + ) + + # Hubert + self.mask_prob = args.mask_prob + self.mask_selection = args.mask_selection + self.mask_other = args.mask_other + self.hubert_mask_length = args.hubert_mask_length + self.no_mask_overlap = args.no_mask_overlap + 
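# (Editorial note, not in the original file) The fields above configure
# HuBERT-style span masking over time steps (mask_prob, hubert_mask_length,
# mask_selection/mask_other, overlap and min-space constraints); the fields
# that follow configure the analogous masking over feature channels. Both sets
# are consumed by apply_hubert_mask() further down via fairseq's
# compute_mask_indices.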
self.mask_min_space = args.mask_min_space + + self.mask_channel_prob = args.mask_channel_prob + self.mask_channel_selection = args.mask_channel_selection + self.mask_channel_other = args.mask_channel_other + self.mask_channel_length = args.mask_channel_length + self.no_mask_channel_overlap = args.no_mask_channel_overlap + self.mask_channel_min_space = args.mask_channel_min_space + + self.mask_emb = nn.Parameter( + torch.FloatTensor(args.encoder_embed_dim).uniform_() + ) + + def forward(self, src_tokens, require_feat_pen=False, target_list=None, padding_mask=None, mask=True): + ft = self.freeze_encoder_updates <= self.num_updates + with torch.no_grad() if not ft else contextlib.ExitStack(): + return self._forward(src_tokens, require_feat_pen, target_list, padding_mask, mask) + + def _forward(self, src_tokens, require_feat_pen=False, target_list=None, padding_mask=None, mask=True): + if self.feature_grad_mult > 0: + x = self.feature_extractor(src_tokens) + x = x.transpose(1, 2).transpose(0, 1) # [length, batch, hidden_size] + if self.feature_grad_mult != 1.0: + x = GradMultiply.apply(x, self.feature_grad_mult) + else: + with torch.no_grad(): + x = self.feature_extractor(src_tokens) + x = x.transpose(1, 2).transpose(0, 1) # [length, batch, hidden_size] + x = x.transpose(0, 1) # [batch, length, hidden_size] + + encoder_padding_mask = padding_mask + + x = x.transpose(1, 2) # [batch, hidden_size, length] + if target_list is not None: + x, target_list = self.forward_targets(x, target_list) + features_pen = x.float().pow(2).mean() + x = x.transpose(1, 2) # [batch, length, hidden_size] + x = self.layer_norm(x) + encoder_padding_mask = self.forward_padding_mask(x, encoder_padding_mask) + if self.post_extract_proj is not None: + x = self.post_extract_proj(x) + x = self.dropout_module(x) + if mask: + x, mask_indices = self.apply_hubert_mask( + x, encoder_padding_mask + ) + else: + x = x + mask_indices = None + + if self.use_conv_pos: + positions = self.pos_conv(x.transpose(1, 2)) + positions = positions.transpose(1, 2) + #else: + # positions = self.embed_positions(encoder_padding_mask) + x = x + positions + + if self.use_sinc_pos: + positions = self.embed_positions(encoder_padding_mask) + x = x + positions + + # x = self.dropout_module(x) + + if require_feat_pen: + return (x, features_pen, mask_indices, target_list), encoder_padding_mask + else: + # For consistence with encoder + return x, encoder_padding_mask + + def forward_targets( + self, features: torch.Tensor, target_list: List[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Trim features to ensure labels exist and then get aligned labels + feat_tsz = features.size(2) + targ_tsz = min([t.size(1) for t in target_list]) + if self.feat2tar_ratio * feat_tsz > targ_tsz: + feat_tsz = int(targ_tsz / self.feat2tar_ratio) + features = features[..., :feat_tsz] + target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio + target_list = [t[:, target_inds.long()] for t in target_list] + return features, target_list + + def forward_padding_mask( + self, features: torch.Tensor, padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view( + padding_mask.size(0), features.size(1), -1 + ) + padding_mask = padding_mask.all(-1) + return padding_mask + + def get_src_lengths(self, src_lengths): + return self.feature_extractor.get_out_seq_lens_tensor(src_lengths) + + def apply_hubert_mask(self, x, padding_mask): + B, 
T, C = x.shape + if self.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.hubert_mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x[mask_indices] = self.mask_emb + else: + mask_indices = None + + if self.mask_channel_prob > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x[mask_channel_indices] = 0 + + return x, mask_indices + + def set_num_updates(self, num_updates): + """Set the number of parameters updates.""" + self.num_updates = num_updates + +class ConvFeatureExtractionModel(nn.Module): + def __init__( + self, + conv_layers: List[Tuple[int, int, int]], + dropout: float = 0.0, + mode: str = "default", + conv_bias: bool = False, + ): + super().__init__() + + assert mode in {"default", "layer_norm"} + + def block( + n_in, + n_out, + k, + stride, + is_layer_norm=False, + is_group_norm=False, + conv_bias=False, + ): + def make_conv(): + conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) + nn.init.kaiming_normal_(conv.weight) + return conv + + assert ( + is_layer_norm and is_group_norm + ) == False, "layer norm and group norm are exclusive" + + if is_layer_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + nn.Sequential( + TransposeLast(), + Fp32LayerNorm(dim, elementwise_affine=True), + TransposeLast(), + ), + nn.GELU(), + ) + elif is_group_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + Fp32GroupNorm(dim, dim, affine=True), + nn.GELU(), + ) + else: + return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU()) + + in_d = 1 + self.conv_layers = nn.ModuleList() + self.conv_layers_infos = conv_layers + for i, cl in enumerate(conv_layers): + assert len(cl) == 3, "invalid conv definition: " + str(cl) + (dim, k, stride) = cl + + self.conv_layers.append( + block( + in_d, + dim, + k, + stride, + is_layer_norm=mode == "layer_norm", + is_group_norm=mode == "default" and i == 0, + conv_bias=conv_bias, + ) + ) + in_d = dim + + def forward(self, x): + # BxT -> BxCxT + x = x.unsqueeze(1) + for conv in self.conv_layers: + x = conv(x) + return x + + def get_out_seq_lens_nonmask_after_a_layer(self, in_seq_lens_tensor, i): + """Returns the out_seq_lens_nonmask 0/1 tensor after a layer. 
+ + Args: + in_seq_lens_tensor (LongTensor): length + + Returns: + LongTensor: length + """ + out_lengths = in_seq_lens_tensor.clone() + out_lengths = ((out_lengths.float() - (self.conv_layers_infos[i][1] - 1) - 1) / self.conv_layers_infos[i][-1] + 1).floor().long() + out_nonmask = (~lengths_to_padding_mask(out_lengths)).float() + return out_nonmask, out_lengths + + def get_out_seq_lens_tensor(self, in_seq_lens_tensor): + out = in_seq_lens_tensor.clone() + for i in range(len(self.conv_layers)): + out = ((out.float() - (self.conv_layers_infos[i][1] - 1) - 1) / self.conv_layers_infos[i][-1] + 1).floor().long() + return out diff --git a/artst/models/modules/text_decoder_postnet.py b/artst/models/modules/text_decoder_postnet.py new file mode 100644 index 0000000000000000000000000000000000000000..92e244a2324816bbe78266e1ebcc120269feac97 --- /dev/null +++ b/artst/models/modules/text_decoder_postnet.py @@ -0,0 +1,92 @@ +# -------------------------------------------------------- +# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621) +# Github source: https://github.com/mbzuai-nlp/ArTST + +# Based on speecht5, fairseq and espnet code bases +# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet +# -------------------------------------------------------- + +import torch.nn as nn +import torch +import contextlib + +from fairseq import utils +from fairseq.modules import ( + AdaptiveSoftmax, +) + +class TextDecoderPostnet(nn.Module): + """ + + Args: + in_channels (int): the number of input channels + mid_channels (int): the number of intermediate channels + out_channels (int): the number of output channels + kernel_sizes (List[int]): the kernel size for each convolutional layer + """ + + def __init__(self, embed_tokens, dictionary, args, output_projection=None,): + super(TextDecoderPostnet, self).__init__() + self.output_embed_dim = args.decoder_output_dim + self.output_projection = output_projection + self.adaptive_softmax = None + self.share_input_output_embed = args.share_input_output_embed + if self.output_projection is None: + self.build_output_projection(args, dictionary, embed_tokens) + self.freeze_decoder_updates = args.freeze_decoder_updates + self.num_updates = 0 + + def output_layer(self, features): + """Project features to the vocabulary size.""" + if self.adaptive_softmax is None: + # project back to size of vocabulary + return self.output_projection(features) + else: + return features + + def build_output_projection(self, args, dictionary, embed_tokens): + if args.adaptive_softmax_cutoff is not None: + self.adaptive_softmax = AdaptiveSoftmax( + len(dictionary), + self.output_embed_dim, + utils.eval_str_list(args.adaptive_softmax_cutoff, type=int), + dropout=args.adaptive_softmax_dropout, + adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None, + factor=args.adaptive_softmax_factor, + tie_proj=args.tie_adaptive_proj, + ) + elif self.share_input_output_embed: + self.output_projection = nn.Linear( + embed_tokens.weight.shape[1], + embed_tokens.weight.shape[0], + bias=False, + ) + self.output_projection.weight = embed_tokens.weight + else: + self.output_projection = nn.Linear( + self.output_embed_dim, len(dictionary), bias=False + ) + nn.init.normal_( + self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5 + ) + # num_base_layers = getattr(args, "base_layers", 0) + # for i in range(num_base_layers): + # self.layers.insert( + # ((i + 1) * args.decoder_layers) // 
(num_base_layers + 1), + # BaseLayer(args), + # ) + + def forward(self, x): + ft = self.freeze_decoder_updates <= self.num_updates + with torch.no_grad() if not ft else contextlib.ExitStack(): + return self._forward(x) + + def _forward(self, x): + # embed positions + x = self.output_layer(x) + + return x + + def set_num_updates(self, num_updates): + """Set the number of parameters updates.""" + self.num_updates = num_updates diff --git a/artst/models/modules/text_decoder_prenet.py b/artst/models/modules/text_decoder_prenet.py new file mode 100644 index 0000000000000000000000000000000000000000..db8a6f70a6cc34b5ae8ae36febc4821ca89b8062 --- /dev/null +++ b/artst/models/modules/text_decoder_prenet.py @@ -0,0 +1,128 @@ +# -------------------------------------------------------- +# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621) +# Github source: https://github.com/mbzuai-nlp/ArTST + +# Based on speecht5, fairseq and espnet code bases +# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet +# -------------------------------------------------------- + +import math +import torch.nn as nn +import torch +import contextlib + +from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_ +from fairseq.models.transformer import Linear #,LayerNorm +from fairseq.modules import ( + PositionalEmbedding, + FairseqDropout, + LayerNorm +) + + +class TextDecoderPrenet(nn.Module): + """ + + Args: + in_channels (int): the number of input channels + mid_channels (int): the number of intermediate channels + out_channels (int): the number of output channels + kernel_sizes (List[int]): the kernel size for each convolutional layer + """ + + def __init__(self, embed_tokens, args): + super(TextDecoderPrenet, self).__init__() + self.dropout_module = FairseqDropout( + args.dropout, module_name=self.__class__.__name__ + ) + self.decoder_layerdrop = args.decoder_layerdrop + self.num_updates = 0 + + input_embed_dim = embed_tokens.embedding_dim + embed_dim = args.decoder_embed_dim + self.embed_dim = embed_dim + self.output_embed_dim = args.decoder_output_dim + + self.padding_idx = embed_tokens.padding_idx + + self.embed_tokens = embed_tokens + + self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim) + + if not args.adaptive_input and args.quant_noise_pq > 0: + self.quant_noise = apply_quant_noise_( + nn.Linear(embed_dim, embed_dim, bias=False), + args.quant_noise_pq, + args.quant_noise_pq_block_size, + ) + else: + self.quant_noise = None + + self.project_in_dim = ( + Linear(input_embed_dim, embed_dim, bias=False) + if embed_dim != input_embed_dim + else None + ) + self.embed_positions = ( + PositionalEmbedding( + args.max_text_positions, + embed_dim, + self.padding_idx, + learned=args.decoder_learned_pos, + ) + if not args.no_token_positional_embeddings + else None + ) + export = getattr(args, "export", False) + if getattr(args, "layernorm_embedding", False): + self.layernorm_embedding = LayerNorm(embed_dim, export=export) + else: + self.layernorm_embedding = None + + self.freeze_decoder_updates = args.freeze_decoder_updates + + def forward(self, prev_output_tokens, incremental_state=None): + ft = self.freeze_decoder_updates <= self.num_updates + with torch.no_grad() if not ft else contextlib.ExitStack(): + return self._forward(prev_output_tokens, incremental_state) + + def _forward(self, prev_output_tokens, incremental_state=None): + if prev_output_tokens.eq(self.padding_idx).any(): + 
x_mask = prev_output_tokens.eq(self.padding_idx) + else: + x_mask = None + + # embed positions + positions = None + if self.embed_positions is not None: + positions = self.embed_positions( + prev_output_tokens, incremental_state=incremental_state + ) + + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + if positions is not None: + positions = positions[:, -1:] + + # embed tokens and positions + x = self.embed_scale * self.embed_tokens(prev_output_tokens) + + if self.quant_noise is not None: + x = self.quant_noise(x) + + if self.project_in_dim is not None: + x = self.project_in_dim(x) + + if positions is not None: + x += positions + + if self.layernorm_embedding is not None: + x = self.layernorm_embedding(x) + + x = self.dropout_module(x) + + return x, x_mask, incremental_state + + def set_num_updates(self, num_updates): + """Set the number of parameters updates.""" + self.num_updates = num_updates diff --git a/artst/models/modules/text_encoder_prenet.py b/artst/models/modules/text_encoder_prenet.py new file mode 100644 index 0000000000000000000000000000000000000000..75c248736964d7d904823b7625d2e9080881917d --- /dev/null +++ b/artst/models/modules/text_encoder_prenet.py @@ -0,0 +1,44 @@ +# -------------------------------------------------------- +# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621) +# Github source: https://github.com/mbzuai-nlp/ArTST + +# Based on speecht5, fairseq and espnet code bases +# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet +# -------------------------------------------------------- + +import torch.nn as nn + +from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding +from espnet.nets.pytorch_backend.transformer.embedding import ScaledPositionalEncoding + + +class TextEncoderPrenet(nn.Module): + """ + + Args: + in_channels (int): the number of input channels + mid_channels (int): the number of intermediate channels + out_channels (int): the number of output channels + kernel_sizes (List[int]): the kernel size for each convolutional layer + """ + + def __init__( + self, + embed_tokens, + args, + ): + super(TextEncoderPrenet, self).__init__() + self.padding_idx = embed_tokens.padding_idx + # define encoder prenet + # get positional encoding class + pos_enc_class = ( + ScaledPositionalEncoding if args.enc_use_scaled_pos_enc else PositionalEncoding + ) + + self.encoder_prenet = nn.Sequential( + embed_tokens, + pos_enc_class(args.encoder_embed_dim, args.transformer_enc_positional_dropout_rate, max_len=args.max_text_positions), + ) + + def forward(self, src_tokens): + return self.encoder_prenet(src_tokens), src_tokens.eq(self.padding_idx) diff --git a/artst/models/modules/transformer_layer.py b/artst/models/modules/transformer_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..c5b994f8664f5d8e16c457a2520b09707d842f93 --- /dev/null +++ b/artst/models/modules/transformer_layer.py @@ -0,0 +1,410 @@ +# -------------------------------------------------------- +# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621) +# Github source: https://github.com/mbzuai-nlp/ArTST + +# Based on speecht5, fairseq and espnet code bases +# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet +# -------------------------------------------------------- + +from typing import Dict, List, Optional + +import 
torch +import torch.nn as nn +import contextlib +from fairseq import utils +from fairseq.modules import LayerNorm +from .multihead_attention import MultiheadAttention +from fairseq.modules.fairseq_dropout import FairseqDropout +from fairseq.modules.quant_noise import quant_noise +from torch import Tensor + + +class TransformerSentenceEncoderLayer(nn.Module): + """ + Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained + models. + """ + + def __init__( + self, + embedding_dim: float = 768, + ffn_embedding_dim: float = 3072, + num_attention_heads: float = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + layer_norm_first: bool = False, + has_relative_attention_bias: bool = False, + ) -> None: + + super().__init__() + # Initialize parameters + self.embedding_dim = embedding_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + + # Initialize blocks + self.activation_fn = utils.get_activation_fn(activation_fn) + self.self_attn = MultiheadAttention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + has_relative_attention_bias=has_relative_attention_bias, + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(self.activation_dropout) + self.dropout3 = nn.Dropout(dropout) + + self.layer_norm_first = layer_norm_first + + # layer norm associated with the self attention layer + self.self_attn_layer_norm = LayerNorm(self.embedding_dim) + self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) + self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) + + # layer norm associated with the position wise feed-forward NN + self.final_layer_norm = LayerNorm(self.embedding_dim) + + if has_relative_attention_bias: + self.norm_k = LayerNorm(self.embedding_dim//num_attention_heads) + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + att_args=None, + pos_bias=None, + ): + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer imlementation. + """ + residual = x + + if self.layer_norm_first: + x = self.self_attn_layer_norm(x) + if pos_bias is not None: + pos_bias = self.norm_k(pos_bias) + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + attn_mask=self_attn_mask, + position_bias=pos_bias, + ) + x = self.dropout1(x) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + else: + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + position_bias=pos_bias, + ) + + x = self.dropout1(x) + x = residual + x + + x = self.self_attn_layer_norm(x) + + residual = x + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + x = self.final_layer_norm(x) + + return x, attn + + +class TransformerDecoderLayer(nn.Module): + """Decoder layer block. + + In the original paper each operation (multi-head attention, encoder + attention or FFN) is postprocessed with: `dropout -> add residual -> + layernorm`. In the tensor2tensor code they suggest that learning is more + robust when preprocessing each layer with layernorm and postprocessing with: + `dropout -> add residual`. 
We default to the approach in the paper, but the + tensor2tensor approach can be enabled by setting + *args.decoder_normalize_before* to ``True``. + + Args: + args (argparse.Namespace): parsed command-line arguments + no_encoder_attn (bool, optional): whether to attend to encoder outputs + (default: False). + """ + + def __init__( + self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False, has_relative_attention_bias=False + ): + super().__init__() + self.embed_dim = args.decoder_embed_dim + self.num_updates = 0 + self.dropout_module = FairseqDropout( + args.dropout, module_name=self.__class__.__name__ + ) + self.quant_noise = getattr(args, "quant_noise_pq", 0) + self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8) + + self.cross_self_attention = getattr(args, "cross_self_attention", False) + + self.freeze_decoder_updates = getattr(args, "freeze_decoder_updates", 0) + + self.self_attn = self.build_self_attention( + self.embed_dim, + args, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + ) + + self.activation_fn = utils.get_activation_fn( + activation=str(args.activation_fn) + if getattr(args, "activation_fn", None) is not None + else "relu" + ) + activation_dropout_p = getattr(args, "activation_dropout", 0) or 0 + if activation_dropout_p == 0: + # for backwards compatibility with models that use args.relu_dropout + activation_dropout_p = getattr(args, "relu_dropout", 0) or 0 + self.activation_dropout_module = FairseqDropout( + float(activation_dropout_p), module_name=self.__class__.__name__ + ) + self.normalize_before = args.decoder_normalize_before + + export = getattr(args, "export", False) + self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export) + + if no_encoder_attn: + self.encoder_attn = None + self.encoder_attn_layer_norm = None + else: + self.encoder_attn = self.build_encoder_attention(self.embed_dim, args) + self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export) + + self.fc1 = self.build_fc1( + self.embed_dim, + args.decoder_ffn_embed_dim, + self.quant_noise, + self.quant_noise_block_size, + ) + self.fc2 = self.build_fc2( + args.decoder_ffn_embed_dim, + self.embed_dim, + self.quant_noise, + self.quant_noise_block_size, + ) + + self.final_layer_norm = LayerNorm(self.embed_dim, export=export) + self.need_attn = True + + self.onnx_trace = False + + self.has_relative_attention_bias = has_relative_attention_bias + if self.has_relative_attention_bias: + self.norm_k = LayerNorm(self.embed_dim//args.decoder_attention_heads) + + def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): + return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) + + def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size): + return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) + + def build_self_attention( + self, embed_dim, args, add_bias_kv=False, add_zero_attn=False + ): + return MultiheadAttention( + embed_dim, + args.decoder_attention_heads, + dropout=args.attention_dropout, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + self_attention=not getattr(args, "cross_self_attention", False), + q_noise=self.quant_noise, + qn_block_size=self.quant_noise_block_size, + #has_relative_attention_bias=args.has_relative_attention_bias, + ) + + def build_encoder_attention(self, embed_dim, args): + return MultiheadAttention( + embed_dim, + args.decoder_attention_heads, + kdim=getattr(args, "encoder_embed_dim", None), + vdim=getattr(args, "encoder_embed_dim", None), 
+ dropout=args.attention_dropout, + encoder_decoder_attention=True, + q_noise=self.quant_noise, + qn_block_size=self.quant_noise_block_size, + ) + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + def residual_connection(self, x, residual): + return residual + x + + def forward( + self, + x, + encoder_out: Optional[torch.Tensor] = None, + encoder_padding_mask: Optional[torch.Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + prev_self_attn_state: Optional[List[torch.Tensor]] = None, + prev_attn_state: Optional[List[torch.Tensor]] = None, + self_attn_mask: Optional[torch.Tensor] = None, + self_attn_padding_mask: Optional[torch.Tensor] = None, + need_attn: bool = False, + need_head_weights: bool = False, + pos_bias=None, + ): + """ + Args: + x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_padding_mask (ByteTensor, optional): binary + ByteTensor of shape `(batch, src_len)` where padding + elements are indicated by ``1``. + need_attn (bool, optional): return attention weights + need_head_weights (bool, optional): return attention weights + for each head (default: return average over heads). + + Returns: + encoded output of shape `(seq_len, batch, embed_dim)` + """ + ft = self.freeze_decoder_updates <= self.num_updates + + with torch.no_grad() if not ft else contextlib.ExitStack(): + if need_head_weights: + need_attn = True + + residual = x + if self.normalize_before: + x = self.self_attn_layer_norm(x) + if pos_bias is not None: + pos_bias = self.norm_k(pos_bias) + if prev_self_attn_state is not None: + prev_key, prev_value = prev_self_attn_state[:2] + saved_state: Dict[str, Optional[Tensor]] = { + "prev_key": prev_key, + "prev_value": prev_value, + } + if len(prev_self_attn_state) >= 3: + saved_state["prev_key_padding_mask"] = prev_self_attn_state[2] + assert incremental_state is not None + self.self_attn._set_input_buffer(incremental_state, saved_state) + _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state) + if self.cross_self_attention and not ( + incremental_state is not None + and _self_attn_input_buffer is not None + and "prev_key" in _self_attn_input_buffer + ): + if self_attn_mask is not None: + assert encoder_out is not None + self_attn_mask = torch.cat( + (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1 + ) + if self_attn_padding_mask is not None: + if encoder_padding_mask is None: + assert encoder_out is not None + encoder_padding_mask = self_attn_padding_mask.new_zeros( + encoder_out.size(1), encoder_out.size(0) + ) + self_attn_padding_mask = torch.cat( + (encoder_padding_mask, self_attn_padding_mask), dim=1 + ) + assert encoder_out is not None + y = torch.cat((encoder_out, x), dim=0) + else: + y = x + + x, attn = self.self_attn( + query=x, + key=y, + value=y, + key_padding_mask=self_attn_padding_mask, + incremental_state=incremental_state, + need_weights=False, + attn_mask=self_attn_mask, + position_bias=pos_bias, + ) + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.self_attn_layer_norm(x) + + if self.encoder_attn is not None and encoder_out is not None: + residual = x + if self.normalize_before: + x = self.encoder_attn_layer_norm(x) + if prev_attn_state is not None: + prev_key, prev_value = prev_attn_state[:2] + saved_state: Dict[str, Optional[Tensor]] = { + "prev_key": prev_key, + "prev_value": prev_value, + } + if len(prev_attn_state) >= 3: + 
saved_state["prev_key_padding_mask"] = prev_attn_state[2] + assert incremental_state is not None + self.encoder_attn._set_input_buffer(incremental_state, saved_state) + + x, attn = self.encoder_attn( + query=x, + key=encoder_out, + value=encoder_out, + key_padding_mask=encoder_padding_mask, + incremental_state=incremental_state, + static_kv=True, + need_weights=need_attn or (not self.training and self.need_attn), + need_head_weights=need_head_weights, + ) + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.encoder_attn_layer_norm(x) + + with torch.no_grad() if not ft else contextlib.ExitStack(): + residual = x + if self.normalize_before: + x = self.final_layer_norm(x) + + x = self.activation_fn(self.fc1(x)) + x = self.activation_dropout_module(x) + x = self.fc2(x) + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.final_layer_norm(x) + if self.onnx_trace and incremental_state is not None: + saved_state = self.self_attn._get_input_buffer(incremental_state) + assert saved_state is not None + if self_attn_padding_mask is not None: + self_attn_state = [ + saved_state["prev_key"], + saved_state["prev_value"], + saved_state["prev_key_padding_mask"], + ] + else: + self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]] + return x, attn, self_attn_state + return x, attn, None + + def make_generation_fast_(self, need_attn: bool = False, **kwargs): + self.need_attn = need_attn + + def set_num_updates(self, num_updates): + """Set the number of parameters updates.""" + self.num_updates = num_updates diff --git a/artst/models/t5_transformer_lm.py b/artst/models/t5_transformer_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..7f4450b088bc42c5af73daed2bf8a9b59071b81a --- /dev/null +++ b/artst/models/t5_transformer_lm.py @@ -0,0 +1,23 @@ +# -------------------------------------------------------- +# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621) +# Github source: https://github.com/mbzuai-nlp/ArTST +# Based on speecht5, fairseq and espnet code bases +# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet +# -------------------------------------------------------- + +from fairseq.models import ( + register_model_architecture, +) +from fairseq.models.transformer_lm import base_lm_architecture + + +# @register_model_architecture(model_name="transformer_lm", arch_name="transformer_lm_t5") +def transformer_lm_t5(args): + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1280) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 6144) + args.decoder_layers = getattr(args, "decoder_layers", 20) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_fn = getattr(args, "activation_fn", "gelu") + base_lm_architecture(args) diff --git a/artst/sequence_generator.py b/artst/sequence_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..6a335fc68a79d8b43060d364f31741ec76aed6c0 --- /dev/null +++ b/artst/sequence_generator.py @@ -0,0 +1,1080 @@ +# -------------------------------------------------------- +# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621) +# Github source: https://github.com/mbzuai-nlp/ArTST +# Based on 
speecht5, fairseq and espnet code bases
+# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
+# --------------------------------------------------------
+
+import math
+from typing import Dict, List, Optional
+import sys
+
+import torch
+import torch.nn as nn
+from fairseq import search, utils
+from fairseq.data import data_utils
+from fairseq.models import FairseqIncrementalDecoder
+from torch import Tensor
+from fairseq.ngram_repeat_block import NGramRepeatBlock
+from espnet.nets.ctc_prefix_score import CTCPrefixScore
+import numpy
+
+CTC_SCORING_RATIO = 7.0
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+class SequenceGenerator(nn.Module):
+    def __init__(
+        self,
+        models,
+        tgt_dict,
+        beam_size=1,
+        max_len_a=0,
+        max_len_b=200,
+        max_len=0,
+        min_len=1,
+        normalize_scores=True,
+        len_penalty=1.0,
+        unk_penalty=0.0,
+        temperature=1.0,
+        match_source_len=False,
+        no_repeat_ngram_size=0,
+        search_strategy=None,
+        eos=None,
+        symbols_to_strip_from_output=None,
+        lm_model=None,
+        lm_weight=1.0,
+        ctc_weight=0.0,
+    ):
+        """Generates translations of a given source sentence.
+
+        Args:
+            models (List[~fairseq.models.FairseqModel]): ensemble of models,
+                currently support fairseq.models.TransformerModel for scripting
+            beam_size (int, optional): beam width (default: 1)
+            max_len_a/b (int, optional): generate sequences of maximum length
+                ax + b, where x is the source length
+            max_len (int, optional): the maximum length of the generated output
+                (not including end-of-sentence)
+            min_len (int, optional): the minimum length of the generated output
+                (not including end-of-sentence)
+            normalize_scores (bool, optional): normalize scores by the length
+                of the output (default: True)
+            len_penalty (float, optional): length penalty, where <1.0 favors
+                shorter, >1.0 favors longer sentences (default: 1.0)
+            unk_penalty (float, optional): unknown word penalty, where <0
+                produces more unks, >0 produces fewer (default: 0.0)
+            temperature (float, optional): temperature, where values
+                >1.0 produce more uniform samples and values <1.0 produce
+                sharper samples (default: 1.0)
+            match_source_len (bool, optional): outputs should match the source
+                length (default: False)
+        """
+        super().__init__()
+        if isinstance(models, EnsembleModel):
+            self.model = models
+        else:
+            self.model = EnsembleModel(models)
+        self.tgt_dict = tgt_dict
+        self.pad = tgt_dict.pad()
+        self.unk = tgt_dict.unk()
+        self.eos = tgt_dict.eos() if eos is None else eos
+        self.blank = self.tgt_dict.index("<ctc_blank>")
+        self.mask = self.tgt_dict.index("<mask>")
+        self.mask_idxs = []
+        if self.tgt_dict.index("<mask>0") != self.unk:
+            count = 0
+            while self.tgt_dict.index("<mask>" + str(count)) != self.unk:
+                self.mask_idxs.append(self.tgt_dict.index("<mask>" + str(count)))
+                count += 1
+        self.mask_idxs = torch.tensor(self.mask_idxs)
+        self.symbols_to_strip_from_output = (
+            symbols_to_strip_from_output.union({self.eos})
+            if symbols_to_strip_from_output is not None
+            else {self.eos}
+        )
+        self.vocab_size = len(tgt_dict)
+        self.beam_size = beam_size
+        # the max beam size is the dictionary size - 1, since we never select pad
+        self.beam_size = min(beam_size, self.vocab_size - 1)
+        self.max_len_a = max_len_a
+        self.max_len_b = max_len_b
+        self.min_len = min_len
+        self.max_len = max_len or self.model.max_decoder_positions()
+
+        self.normalize_scores = normalize_scores
+        self.len_penalty = len_penalty
+        self.unk_penalty = unk_penalty
+        self.temperature = temperature
+        self.match_source_len =
match_source_len + + if no_repeat_ngram_size > 0: + self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size) + else: + self.repeat_ngram_blocker = None + + assert temperature > 0, "--temperature must be greater than 0" + + self.search = ( + search.BeamSearch(tgt_dict) if search_strategy is None else search_strategy + ) + # We only need to set src_lengths in LengthConstrainedBeamSearch. + # As a module attribute, setting it would break in multithread + # settings when the model is shared. + self.should_set_src_lengths = ( + hasattr(self.search, "needs_src_lengths") and self.search.needs_src_lengths + ) + + self.model.eval() + + self.lm_model = lm_model + self.lm_weight = lm_weight + self.ctc_weight = ctc_weight + if self.lm_model is not None: + self.lm_model.eval() + + def cuda(self): + self.model.cuda() + return self + + @torch.no_grad() + def forward( + self, + sample: Dict[str, Dict[str, Tensor]], + prefix_tokens: Optional[Tensor] = None, + bos_token: Optional[int] = None, + ): + """Generate a batch of translations. + + Args: + sample (dict): batch + prefix_tokens (torch.LongTensor, optional): force decoder to begin + with these tokens + bos_token (int, optional): beginning of sentence token + (default: self.eos) + """ + return self._generate(sample, prefix_tokens, bos_token=bos_token) + + # TODO(myleott): unused, deprecate after pytorch-translate migration + def generate_batched_itr(self, data_itr, beam_size=None, cuda=False, timer=None): + """Iterate over a batched dataset and yield individual translations. + Args: + cuda (bool, optional): use GPU for generation + timer (StopwatchMeter, optional): time generations + """ + for sample in data_itr: + s = utils.move_to_cuda(sample) if cuda else sample + if "net_input" not in s: + continue + input = s["net_input"] + # model.forward normally channels prev_output_tokens into the decoder + # separately, but SequenceGenerator directly calls model.encoder + encoder_input = { + k: v for k, v in input.items() if k != "prev_output_tokens" + } + if timer is not None: + timer.start() + with torch.no_grad(): + hypos = self.generate(encoder_input) + if timer is not None: + timer.stop(sum(len(h[0]["tokens"]) for h in hypos)) + for i, id in enumerate(s["id"].data): + # remove padding + src = utils.strip_pad(input["src_tokens"].data[i, :], self.pad) + ref = ( + utils.strip_pad(s["target"].data[i, :], self.pad) + if s["target"] is not None + else None + ) + yield id, src, ref, hypos[i] + + @torch.no_grad() + def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs): + """Generate translations. Match the api of other fairseq generators. 
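+
+        The ``models`` argument is accepted only for API compatibility with other
+        fairseq generators; generation always uses the ensemble wrapped in __init__.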
+ + Args: + models (List[~fairseq.models.FairseqModel]): ensemble of models + sample (dict): batch + prefix_tokens (torch.LongTensor, optional): force decoder to begin + with these tokens + constraints (torch.LongTensor, optional): force decoder to include + the list of constraints + bos_token (int, optional): beginning of sentence token + (default: self.eos) + """ + return self._generate(sample, **kwargs) + + def _generate( + self, + sample: Dict[str, Dict[str, Tensor]], + prefix_tokens: Optional[Tensor] = None, + constraints: Optional[Tensor] = None, + bos_token: Optional[int] = None, + ): + incremental_states = torch.jit.annotate( + List[Dict[str, Dict[str, Optional[Tensor]]]], + [ + torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}) + for i in range(self.model.models_size) + ], + ) + net_input = sample["net_input"] + + if "src_tokens" in net_input: + src_tokens = net_input["src_tokens"] + # length of the source text being the character length except EndOfSentence and pad + src_lengths = ( + (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1) + ) + elif "source" in net_input: + src_tokens = net_input["source"] + src_lengths = ( + net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1) + if net_input["padding_mask"] is not None + else torch.tensor(src_tokens.size(-1)).to(src_tokens) + ) + elif "features" in net_input: + src_tokens = net_input["features"] + src_lengths = ( + net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1) + if net_input["padding_mask"] is not None + else torch.tensor(src_tokens.size(-1)).to(src_tokens) + ) + else: + raise Exception("expected src_tokens or source in net input. input keys: " + str(net_input.keys())) + + # bsz: total number of sentences in beam + # Note that src_tokens may have more than 2 dimensions (i.e. audio features) + bsz, src_len = src_tokens.size()[:2] + beam_size = self.beam_size + + if constraints is not None and not self.search.supports_constraints: + raise NotImplementedError( + "Target-side constraints were provided, but search method doesn't support them" + ) + + # Initialize constraints, when active + self.search.init_constraints(constraints, beam_size) + + max_len: int = -1 + if self.match_source_len: + max_len = src_lengths.max().item() + else: + max_len = min( + int(self.max_len_a * src_len + self.max_len_b), + self.max_len - 1, + ) + assert ( + self.min_len <= max_len + ), "min_len cannot be larger than max_len, please adjust these!" + # compute the encoder output for each beam + encoder_outs = self.model.forward_encoder(net_input) + + # Get CTC lprobs and prep ctc_scorer + if self.ctc_weight > 0: + ctc_lprobs = self.model.models[0].get_normalized_probs_for_ctc( + encoder_outs[0], log_probs=True + ).contiguous().transpose(0, 1) # (B, T, C) from the encoder + + hyp = {} + ctc_prefix_score = CTCPrefixScore(ctc_lprobs[0].detach().cpu().numpy(), self.blank, self.eos, numpy) + hyp["ctc_state_prev"] = ctc_prefix_score.initial_state() + hyp["ctc_score_prev"] = 0.0 + ctc_beam = min(ctc_lprobs.shape[-1] - self.mask_idxs.size(-1), int(beam_size * CTC_SCORING_RATIO)) + ctc_hyps = {str(self.eos): hyp} + + # placeholder of indices for bsz * beam_size to hold tokens and accumulative scores + new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1) + new_order = new_order.to(src_tokens.device).long() + encoder_outs = self.model.reorder_encoder_out(encoder_outs, new_order) + # ensure encoder_outs is a List. 
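+        # At this point every source sentence has been repeated beam_size times via
+        # new_order, so the per-hypothesis buffers below use a leading dim of bsz * beam_size.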
+ assert encoder_outs is not None + + # initialize buffers + scores = ( + torch.zeros(bsz * beam_size, max_len + 1).to(src_tokens).float() + ) # +1 for eos; pad is never chosen for scoring + tokens = ( + torch.zeros(bsz * beam_size, max_len + 2) + .to(src_tokens) + .long() + .fill_(self.pad) + ) # +2 for eos and pad + tokens[:, 0] = self.eos if bos_token is None else bos_token + attn: Optional[Tensor] = None + + # A list that indicates candidates that should be ignored. + # For example, suppose we're sampling and have already finalized 2/5 + # samples. Then cands_to_ignore would mark 2 positions as being ignored, + # so that we only finalize the remaining 3 samples. + cands_to_ignore = ( + torch.zeros(bsz, beam_size).to(src_tokens).eq(-1) + ) # forward and backward-compatible False mask + + # list of completed sentences + finalized = torch.jit.annotate( + List[List[Dict[str, Tensor]]], + [torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)], + ) # contains lists of dictionaries of infomation about the hypothesis being finalized at each step + + # a boolean array indicating if the sentence at the index is finished or not + finished = [False for i in range(bsz)] + num_remaining_sent = bsz # number of sentences remaining + + # number of candidate hypos per step + cand_size = 2 * beam_size # 2 x beam size in case half are EOS + + # offset arrays for converting between different indexing schemes + bbsz_offsets = ( + (torch.arange(0, bsz) * beam_size) + .unsqueeze(1) + .type_as(tokens) + .to(src_tokens.device) + ) + cand_offsets = torch.arange(0, cand_size).type_as(tokens).to(src_tokens.device) + + reorder_state: Optional[Tensor] = None + ctc_state = None + batch_idxs: Optional[Tensor] = None + + original_batch_idxs: Optional[Tensor] = None + if "id" in sample and isinstance(sample["id"], Tensor): + original_batch_idxs = sample["id"] + else: + original_batch_idxs = torch.arange(0, bsz).type_as(tokens) + + for step in range(max_len + 1): # one extra step for EOS marker + # reorder decoder internal states based on the prev choice of beams + if reorder_state is not None: + if batch_idxs is not None: + # update beam indices to take into account removed sentences + corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as( + batch_idxs + ) + reorder_state.view(-1, beam_size).add_( + corr.unsqueeze(-1) * beam_size + ) + original_batch_idxs = original_batch_idxs[batch_idxs] + self.model.reorder_incremental_state(incremental_states, reorder_state) + encoder_outs = self.model.reorder_encoder_out( + encoder_outs, reorder_state + ) + + lprobs, avg_attn_scores = self.model.forward_decoder( + tokens[:, : step + 1], + encoder_outs, + incremental_states, + self.temperature, + ) + + if self.ctc_weight > 0 and step != 0: + # lprobs[:, self.blank] = -math.inf # never select blank + ctc_lprobs = lprobs.clone() + ctc_lprobs[:, self.blank] = -math.inf # never select blank + if self.mask != self.unk: + ctc_lprobs[:, self.mask] = -math.inf # never select mask + if self.mask_idxs.size(0) != 0: + ctc_lprobs[:, self.mask_idxs] = -math.inf # never select mask + local_best_scores, local_best_ids = torch.topk(ctc_lprobs, ctc_beam, dim=-1) + for b in range(tokens.size(0)): + hyp_key = " ".join(str(x) for x in tokens[b, : step + 1].tolist()) + + ctc_scores, ctc_states = ctc_prefix_score( + tokens[b, : step + 1].cpu(), local_best_ids[b].cpu(), ctc_hyps[hyp_key]["ctc_state_prev"] + ) + lprobs[b] = lprobs[b] + lprobs[b, local_best_ids[b]] = (1 - self.ctc_weight) * (lprobs[b, local_best_ids[b]]) + 
self.ctc_weight * torch.from_numpy( + ctc_scores - ctc_hyps[hyp_key]["ctc_score_prev"] + ).to(device=device) + for j in range(len(local_best_ids[b])): + ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())] = {} + ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())]["ctc_score_prev"] = ctc_scores[j] + ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())]["ctc_state_prev"] = ctc_states[j] + + # local_ctc_scores, ctc_state = ctc_scorer( + # tokens[:, : step + 1], ctc_state, part_ids + # ) + # lprobs += local_ctc_scores * self.ctc_weight + elif self.ctc_weight > 0 and step == 0: + ctc_lprobs = lprobs.clone() + ctc_lprobs[:, self.blank] = -math.inf # never select blank + if self.mask != self.unk: + ctc_lprobs[:, self.mask] = -math.inf # never select mask + if self.mask_idxs.size(0) != 0: + ctc_lprobs[:, self.mask_idxs] = -math.inf # never select mask + local_best_scores, local_best_ids = torch.topk(ctc_lprobs, ctc_beam, dim=-1) + for b in range(tokens.size(0)): + hyp_key = " ".join(str(x) for x in tokens[b, : step + 1].tolist()) + ctc_scores, ctc_states = ctc_prefix_score( + tokens[b, : step + 1].cpu(), local_best_ids[b].cpu(), ctc_hyps[hyp_key]["ctc_state_prev"] + ) + lprobs[b] = lprobs[b] + lprobs[b, local_best_ids[b]] = (1 - self.ctc_weight) * (lprobs[b, local_best_ids[b]]) + self.ctc_weight * torch.from_numpy( + ctc_scores - ctc_hyps[hyp_key]["ctc_score_prev"] + ).to(device=device) + for j in range(len(local_best_ids[b])): + if b == 0: + ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())] = {} + ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())]["ctc_score_prev"] = ctc_scores[j] + ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())]["ctc_state_prev"] = ctc_states[j] + + if self.lm_model is not None: + lm_out = self.lm_model(tokens[:, : step + 1]) + probs = self.lm_model.get_normalized_probs( + lm_out, log_probs=True, sample=None + ) + probs = probs[:, -1, :] * self.lm_weight + lprobs[:, :probs.size(1)] += probs + + # handle prefix tokens (possibly with different lengths) + if ( + prefix_tokens is not None + and step < prefix_tokens.size(1) + and step < max_len + ): + lprobs, tokens, scores = self._prefix_tokens( + step, lprobs, scores, tokens, prefix_tokens, beam_size + ) + elif step < self.min_len: + # minimum length constraint (does not apply if using prefix_tokens) + lprobs[:, self.eos] = -math.inf + + lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs) + + lprobs[:, self.pad] = -math.inf # never select pad + lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty + lprobs[:, self.blank] = -math.inf # never select blank + if self.mask != self.unk: + lprobs[:, self.mask] = -math.inf # never select mask + if self.mask_idxs.size(0) != 0: + lprobs[:, self.mask_idxs] = -math.inf # never select mask + + # handle max length constraint + if step >= max_len: + lprobs[:, : self.eos] = -math.inf + lprobs[:, self.eos + 1 :] = -math.inf + + # Record attention scores, only support avg_attn_scores is a Tensor + if avg_attn_scores is not None: + if attn is None: + attn = torch.empty( + bsz * beam_size, avg_attn_scores.size(1), max_len + 2 + ).to(scores) + attn[:, :, step + 1].copy_(avg_attn_scores) + + scores = scores.type_as(lprobs) + eos_bbsz_idx = torch.empty(0).to( + tokens + ) # indices of hypothesis ending with eos (finished sentences) + eos_scores = torch.empty(0).to( + scores + ) # scores of hypothesis ending with eos (finished sentences) + + if self.should_set_src_lengths: + self.search.set_src_lengths(src_lengths) + + if 
self.repeat_ngram_blocker is not None: + lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, beam_size, step) + + # Shape: (batch, cand_size) + cand_scores, cand_indices, cand_beams = self.search.step( + step, + lprobs.view(bsz, -1, self.vocab_size), + scores.view(bsz, beam_size, -1)[:, :, :step], + tokens[:, : step + 1], + original_batch_idxs, + ) + + # cand_bbsz_idx contains beam indices for the top candidate + # hypotheses, with a range of values: [0, bsz*beam_size), + # and dimensions: [bsz, cand_size] + cand_bbsz_idx = cand_beams.add(bbsz_offsets) + + # finalize hypotheses that end in eos + # Shape of eos_mask: (batch size, beam size) + eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf) + eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask) + + # only consider eos when it's among the top beam_size indices + # Now we know what beam item(s) to finish + # Shape: 1d list of absolute-numbered + eos_bbsz_idx = torch.masked_select( + cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size] + ) + + finalized_sents: List[int] = [] + if eos_bbsz_idx.numel() > 0: + eos_scores = torch.masked_select( + cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size] + ) + + finalized_sents = self.finalize_hypos( + step, + eos_bbsz_idx, + eos_scores, + tokens, + scores, + finalized, + finished, + beam_size, + attn, + src_lengths, + max_len, + ) + num_remaining_sent -= len(finalized_sents) + + assert num_remaining_sent >= 0 + if num_remaining_sent == 0: + break + if self.search.stop_on_max_len and step >= max_len: + break + assert step < max_len, f"{step} < {max_len}" + + # Remove finalized sentences (ones for which {beam_size} + # finished hypotheses have been generated) from the batch. + if len(finalized_sents) > 0: + new_bsz = bsz - len(finalized_sents) + + # construct batch_idxs which holds indices of batches to keep for the next pass + batch_mask = torch.ones( + bsz, dtype=torch.bool, device=cand_indices.device + ) + batch_mask[finalized_sents] = False + # TODO replace `nonzero(as_tuple=False)` after TorchScript supports it + batch_idxs = torch.arange( + bsz, device=cand_indices.device + ).masked_select(batch_mask) + + # Choose the subset of the hypothesized constraints that will continue + self.search.prune_sentences(batch_idxs) + + eos_mask = eos_mask[batch_idxs] + cand_beams = cand_beams[batch_idxs] + bbsz_offsets.resize_(new_bsz, 1) + cand_bbsz_idx = cand_beams.add(bbsz_offsets) + cand_scores = cand_scores[batch_idxs] + cand_indices = cand_indices[batch_idxs] + + if prefix_tokens is not None: + prefix_tokens = prefix_tokens[batch_idxs] + src_lengths = src_lengths[batch_idxs] + cands_to_ignore = cands_to_ignore[batch_idxs] + + scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1) + tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1) + if attn is not None: + attn = attn.view(bsz, -1)[batch_idxs].view( + new_bsz * beam_size, attn.size(1), -1 + ) + bsz = new_bsz + else: + batch_idxs = None + + # Set active_mask so that values > cand_size indicate eos hypos + # and values < cand_size indicate candidate active hypos. + # After, the min values per row are the top candidate active hypos + + # Rewrite the operator since the element wise or is not supported in torchscript. 
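+            # De Morgan: a | b == ~(~a & ~b); a slot is excluded from further expansion
+            # if it was already being ignored or has just produced EOS.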
+ + eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size])) + active_mask = torch.add( + eos_mask.type_as(cand_offsets) * cand_size, + cand_offsets[: eos_mask.size(1)], + ) + + # get the top beam_size active hypotheses, which are just + # the hypos with the smallest values in active_mask. + # {active_hypos} indicates which {beam_size} hypotheses + # from the list of {2 * beam_size} candidates were + # selected. Shapes: (batch size, beam size) + new_cands_to_ignore, active_hypos = torch.topk( + active_mask, k=beam_size, dim=1, largest=False + ) + + # update cands_to_ignore to ignore any finalized hypos. + cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size] + # Make sure there is at least one active item for each sentence in the batch. + assert (~cands_to_ignore).any(dim=1).all() + + # update cands_to_ignore to ignore any finalized hypos + + # {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam + # can be selected more than once). + active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos) + active_scores = torch.gather(cand_scores, dim=1, index=active_hypos) + + active_bbsz_idx = active_bbsz_idx.view(-1) + active_scores = active_scores.view(-1) + + # copy tokens and scores for active hypotheses + + # Set the tokens for each beam (can select the same row more than once) + tokens[:, : step + 1] = torch.index_select( + tokens[:, : step + 1], dim=0, index=active_bbsz_idx + ) + # Select the next token for each of them + tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather( + cand_indices, dim=1, index=active_hypos + ) + if step > 0: + scores[:, :step] = torch.index_select( + scores[:, :step], dim=0, index=active_bbsz_idx + ) + scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather( + cand_scores, dim=1, index=active_hypos + ) + + # Update constraints based on which candidates were selected for the next beam + self.search.update_constraints(active_hypos) + + # copy attention for active hypotheses + if attn is not None: + attn[:, :, : step + 2] = torch.index_select( + attn[:, :, : step + 2], dim=0, index=active_bbsz_idx + ) + + # reorder incremental state in decoder + reorder_state = active_bbsz_idx + + # if self.ctc_weight > 0: + # accum_best_id = torch.gather(cand_indices, dim=1, index=active_hypos) + # ctc_state = ctc_scorer.index_select_state( + # ctc_state, accum_best_id + # ) + + # sort by score descending + for sent in range(len(finalized)): + scores = torch.tensor( + [float(elem["score"].item()) for elem in finalized[sent]] + ) + _, sorted_scores_indices = torch.sort(scores, descending=True) + finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices] + finalized[sent] = torch.jit.annotate( + List[Dict[str, Tensor]], finalized[sent] + ) + return finalized + + def _prefix_tokens( + self, step: int, lprobs, scores, tokens, prefix_tokens, beam_size: int + ): + """Handle prefix tokens""" + prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1) + prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1)) + prefix_mask = prefix_toks.ne(self.pad) + lprobs[prefix_mask] = torch.min(prefix_lprobs) - 1 + lprobs[prefix_mask] = lprobs[prefix_mask].scatter( + -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask] + ) + # if prefix includes eos, then we should make sure tokens and + # scores are the same across all beams + eos_mask = prefix_toks.eq(self.eos) + if eos_mask.any(): + # validate that the first beam matches the prefix + first_beam = 
tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[ + :, 0, 1 : step + 1 + ] + eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0] + target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step] + assert (first_beam == target_prefix).all() + + # copy tokens, scores and lprobs from the first beam to all beams + tokens = self.replicate_first_beam(tokens, eos_mask_batch_dim, beam_size) + scores = self.replicate_first_beam(scores, eos_mask_batch_dim, beam_size) + lprobs = self.replicate_first_beam(lprobs, eos_mask_batch_dim, beam_size) + return lprobs, tokens, scores + + def replicate_first_beam(self, tensor, mask, beam_size: int): + tensor = tensor.view(-1, beam_size, tensor.size(-1)) + tensor[mask] = tensor[mask][:, :1, :] + return tensor.view(-1, tensor.size(-1)) + + def finalize_hypos( + self, + step: int, + bbsz_idx, + eos_scores, + tokens, + scores, + finalized: List[List[Dict[str, Tensor]]], + finished: List[bool], + beam_size: int, + attn: Optional[Tensor], + src_lengths, + max_len: int, + ): + """Finalize hypothesis, store finalized information in `finalized`, and change `finished` accordingly. + A sentence is finalized when {beam_size} finished items have been collected for it. + Returns number of sentences (not beam items) being finalized. + These will be removed from the batch and not processed further. + Args: + bbsz_idx (Tensor): + """ + assert bbsz_idx.numel() == eos_scores.numel() + + # clone relevant token and attention tensors. + # tokens is (batch * beam, max_len). So the index_select + # gets the newly EOS rows, then selects cols 1..{step + 2} + tokens_clone = tokens.index_select(0, bbsz_idx)[ + :, 1 : step + 2 + ] # skip the first index, which is EOS + + tokens_clone[:, step] = self.eos + attn_clone = ( + attn.index_select(0, bbsz_idx)[:, :, 1 : step + 2] + if attn is not None + else None + ) + + # compute scores per token position + pos_scores = scores.index_select(0, bbsz_idx)[:, : step + 1] + pos_scores[:, step] = eos_scores + # convert from cumulative to per-position scores + pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1] + + # normalize sentence-level scores + if self.normalize_scores: + eos_scores /= (step + 1) ** self.len_penalty + + # cum_unfin records which sentences in the batch are finished. + # It helps match indexing between (a) the original sentences + # in the batch and (b) the current, possibly-reduced set of + # sentences. 
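+        # e.g. finished = [True, False, True, False] gives cum_unfin = [1, 2]:
+        # reduced sentence 0 maps back to original sentence 1, reduced 1 to original 3.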
+ cum_unfin: List[int] = [] + prev = 0 + for f in finished: + if f: + prev += 1 + else: + cum_unfin.append(prev) + cum_fin_tensor = torch.tensor(cum_unfin, dtype=torch.int).to(bbsz_idx) + + unfin_idx = bbsz_idx // beam_size + sent = unfin_idx + torch.index_select(cum_fin_tensor, 0, unfin_idx) + + # Create a set of "{sent}{unfin_idx}", where + # "unfin_idx" is the index in the current (possibly reduced) + # list of sentences, and "sent" is the index in the original, + # unreduced batch + # For every finished beam item + # sentence index in the current (possibly reduced) batch + seen = (sent << 32) + unfin_idx + unique_seen: List[int] = torch.unique(seen).tolist() + + if self.match_source_len: + condition = step > torch.index_select(src_lengths, 0, unfin_idx) + eos_scores = torch.where(condition, torch.tensor(-math.inf), eos_scores) + sent_list: List[int] = sent.tolist() + for i in range(bbsz_idx.size()[0]): + # An input sentence (among those in a batch) is finished when + # beam_size hypotheses have been collected for it + if len(finalized[sent_list[i]]) < beam_size: + if attn_clone is not None: + # remove padding tokens from attn scores + hypo_attn = attn_clone[i] + else: + hypo_attn = torch.empty(0) + + finalized[sent_list[i]].append( + { + "tokens": tokens_clone[i], + "score": eos_scores[i], + "attention": hypo_attn, # src_len x tgt_len + "alignment": torch.empty(0), + "positional_scores": pos_scores[i], + } + ) + + newly_finished: List[int] = [] + for unique_s in unique_seen: + # check termination conditions for this sentence + unique_sent: int = unique_s >> 32 + unique_unfin_idx: int = unique_s - (unique_sent << 32) + + if not finished[unique_sent] and self.is_finished( + step, unique_unfin_idx, max_len, len(finalized[unique_sent]), beam_size + ): + finished[unique_sent] = True + newly_finished.append(unique_unfin_idx) + + return newly_finished + + def is_finished( + self, + step: int, + unfin_idx: int, + max_len: int, + finalized_sent_len: int, + beam_size: int, + ): + """ + Check whether decoding for a sentence is finished, which + occurs when the list of finalized sentences has reached the + beam size, or when we reach the maximum length. 
+ """ + assert finalized_sent_len <= beam_size + if finalized_sent_len == beam_size or step == max_len: + return True + return False + + +class EnsembleModel(nn.Module): + """A wrapper around an ensemble of models.""" + + def __init__(self, models): + super().__init__() + self.models_size = len(models) + # method '__len__' is not supported in ModuleList for torch script + self.single_model = models[0] + self.models = nn.ModuleList(models) + + self.has_incremental: bool = False + if all( + hasattr(m, "decoder") and isinstance(m.decoder, FairseqIncrementalDecoder) + for m in models + ): + self.has_incremental = True + + def forward(self): + pass + + def has_encoder(self): + return hasattr(self.single_model, "encoder") + + def is_t5_structure(self): + t5_structure = hasattr(self.single_model, "text_encoder_prenet") and hasattr(self.single_model, "speech_encoder_prenet") or \ + hasattr(self.single_model, "encoder_prenet") and hasattr(self.single_model, "encoder_prenet") + return t5_structure + + def has_incremental_states(self): + return self.has_incremental + + def max_decoder_positions(self): + return min([m.max_decoder_positions() for m in self.models if hasattr(m, "max_decoder_positions")] + [sys.maxsize]) + + @torch.jit.export + def forward_encoder(self, net_input: Dict[str, Tensor]): + if not self.has_encoder(): + return None + elif self.is_t5_structure(): + return [model.forward_encoder_torchscript(net_input) for model in self.models] + else: + return [model.encoder.forward_torchscript(net_input) for model in self.models] + + @torch.jit.export + def forward_decoder( + self, + tokens, + encoder_outs: List[Dict[str, List[Tensor]]], + incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]], + temperature: float = 1.0, + ): + log_probs = [] + avg_attn: Optional[Tensor] = None + encoder_out: Optional[Dict[str, List[Tensor]]] = None + for i, model in enumerate(self.models): + if self.has_encoder(): + encoder_out = encoder_outs[i] + # decode each model + if self.has_incremental_states(): + if self.is_t5_structure: + decoder_out = model.forward_decoder( + tokens, + encoder_out=encoder_out, + incremental_state=incremental_states[i] + ) + else: + decoder_out = model.decoder.forward( + tokens, + encoder_out=encoder_out, + incremental_state=incremental_states[i], + ) + else: + if hasattr(model, "decoder"): + decoder_out = model.decoder.forward(tokens, encoder_out=encoder_out) + else: + decoder_out = model.forward(tokens) + + attn: Optional[Tensor] = None + decoder_len = len(decoder_out) + if decoder_len > 1 and decoder_out[1] is not None: + if isinstance(decoder_out[1], Tensor): + attn = decoder_out[1] + else: + attn_holder = decoder_out[1]["attn"] + if isinstance(attn_holder, Tensor): + attn = attn_holder + elif attn_holder is not None: + attn = attn_holder[0] + if attn is not None: + attn = attn[:, -1, :] + + decoder_out_tuple = ( + decoder_out[0][:, -1:, :].div_(temperature), + None if decoder_len <= 1 else decoder_out[1], + ) + probs = model.get_normalized_probs( + decoder_out_tuple, log_probs=True, sample=None + ) + probs = probs[:, -1, :] + if self.models_size == 1: + return probs, attn + + log_probs.append(probs) + if attn is not None: + if avg_attn is None: + avg_attn = attn + else: + avg_attn.add_(attn) + + avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log( + self.models_size + ) + + if avg_attn is not None: + avg_attn.div_(self.models_size) + return avg_probs, avg_attn + + @torch.jit.export + def reorder_encoder_out( + self, encoder_outs: 
Optional[List[Dict[str, List[Tensor]]]], new_order + ): + """ + Reorder encoder output according to *new_order*. + + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + *encoder_out* rearranged according to *new_order* + """ + new_outs: List[Dict[str, List[Tensor]]] = [] + if not self.has_encoder(): + return new_outs + for i, model in enumerate(self.models): + assert encoder_outs is not None + new_outs.append( + model.encoder.reorder_encoder_out(encoder_outs[i], new_order) + ) + return new_outs + + @torch.jit.export + def reorder_incremental_state( + self, + incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]], + new_order, + ): + if not self.has_incremental_states(): + return + for i, model in enumerate(self.models): + model.decoder.reorder_incremental_state_scripting( + incremental_states[i], new_order + ) + + +class SequenceGeneratorWithAlignment(SequenceGenerator): + def __init__( + self, models, tgt_dict, left_pad_target=False, print_alignment="hard", **kwargs + ): + """Generates translations of a given source sentence. + + Produces alignments following "Jointly Learning to Align and + Translate with Transformer Models" (Garg et al., EMNLP 2019). + + Args: + left_pad_target (bool, optional): Whether or not the + hypothesis should be left padded or not when they are + teacher forced for generating alignments. + """ + super().__init__(EnsembleModelWithAlignment(models), tgt_dict, **kwargs) + self.left_pad_target = left_pad_target + + if print_alignment == "hard": + self.extract_alignment = utils.extract_hard_alignment + elif print_alignment == "soft": + self.extract_alignment = utils.extract_soft_alignment + + @torch.no_grad() + def generate(self, models, sample, **kwargs): + finalized = super()._generate(sample, **kwargs) + + src_tokens = sample["net_input"]["src_tokens"] + bsz = src_tokens.shape[0] + beam_size = self.beam_size + ( + src_tokens, + src_lengths, + prev_output_tokens, + tgt_tokens, + ) = self._prepare_batch_for_alignment(sample, finalized) + if any(getattr(m, "full_context_alignment", False) for m in self.model.models): + attn = self.model.forward_align(src_tokens, src_lengths, prev_output_tokens) + else: + attn = [ + finalized[i // beam_size][i % beam_size]["attention"].transpose(1, 0) + for i in range(bsz * beam_size) + ] + + if src_tokens.device != "cpu": + src_tokens = src_tokens.to("cpu") + tgt_tokens = tgt_tokens.to("cpu") + attn = [i.to("cpu") for i in attn] + + # Process the attn matrix to extract hard alignments. 
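+        # Flat hypothesis i belongs to sentence i // beam_size, beam slot i % beam_size;
+        # its attention matrix is reduced to an alignment and stored on the finalized hypo.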
+ for i in range(bsz * beam_size): + alignment = self.extract_alignment( + attn[i], src_tokens[i], tgt_tokens[i], self.pad, self.eos + ) + finalized[i // beam_size][i % beam_size]["alignment"] = alignment + return finalized + + def _prepare_batch_for_alignment(self, sample, hypothesis): + src_tokens = sample["net_input"]["src_tokens"] + bsz = src_tokens.shape[0] + src_tokens = ( + src_tokens[:, None, :] + .expand(-1, self.beam_size, -1) + .contiguous() + .view(bsz * self.beam_size, -1) + ) + src_lengths = sample["net_input"]["src_lengths"] + src_lengths = ( + src_lengths[:, None] + .expand(-1, self.beam_size) + .contiguous() + .view(bsz * self.beam_size) + ) + prev_output_tokens = data_utils.collate_tokens( + [beam["tokens"] for example in hypothesis for beam in example], + self.pad, + self.eos, + self.left_pad_target, + move_eos_to_beginning=True, + ) + tgt_tokens = data_utils.collate_tokens( + [beam["tokens"] for example in hypothesis for beam in example], + self.pad, + self.eos, + self.left_pad_target, + move_eos_to_beginning=False, + ) + return src_tokens, src_lengths, prev_output_tokens, tgt_tokens + + +class EnsembleModelWithAlignment(EnsembleModel): + """A wrapper around an ensemble of models.""" + + def __init__(self, models): + super().__init__(models) + + def forward_align(self, src_tokens, src_lengths, prev_output_tokens): + avg_attn = None + for model in self.models: + decoder_out = model(src_tokens, src_lengths, prev_output_tokens) + attn = decoder_out[1]["attn"][0] + if avg_attn is None: + avg_attn = attn + else: + avg_attn.add_(attn) + if len(self.models) > 1: + avg_attn.div_(len(self.models)) + return avg_attn diff --git a/artst/tasks/__init__.py b/artst/tasks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/artst/tasks/__pycache__/__init__.cpython-38.pyc b/artst/tasks/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1763d45d6c93f12b2e42d734da1bb91fab13fb56 Binary files /dev/null and b/artst/tasks/__pycache__/__init__.cpython-38.pyc differ diff --git a/artst/tasks/__pycache__/artst.cpython-38.pyc b/artst/tasks/__pycache__/artst.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..928d8398389433a97a99bb0ec4ea3598d40ab9f1 Binary files /dev/null and b/artst/tasks/__pycache__/artst.cpython-38.pyc differ diff --git a/artst/tasks/__pycache__/speecht5.cpython-38.pyc b/artst/tasks/__pycache__/speecht5.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..634794db046d5789669eaabc882e2f207f3a8f8a Binary files /dev/null and b/artst/tasks/__pycache__/speecht5.cpython-38.pyc differ diff --git a/artst/tasks/artst.py b/artst/tasks/artst.py new file mode 100644 index 0000000000000000000000000000000000000000..137a8be461976ef0029ad91bc278045ecae79fe2 --- /dev/null +++ b/artst/tasks/artst.py @@ -0,0 +1,711 @@ +# -------------------------------------------------------- +# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621) +# Github source: https://github.com/mbzuai-nlp/ArTST + +# Based on speecht5, fairseq and espnet code bases +# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet +# -------------------------------------------------------- + +import logging +import os.path as op +from argparse import Namespace +from collections import OrderedDict + +import torch +from fairseq.data import ( + 
Dictionary, + encoders, + PrependTokenDataset, + AppendTokenDataset, + data_utils, + StripTokenDataset, + TokenBlockDataset, +) +from fairseq.data.encoders.utils import get_whole_word_mask +from fairseq import utils +from artst.data.multitask_dataset import MultitaskDataset +from artst.data.speech_to_text_dataset import SpeechToTextDataset +from artst.data.text_to_speech_dataset import TextToSpeechDataset +from artst.data.speech_to_speech_dataset import SpeechToSpeechDataset +from artst.data.speech_to_class_dataset import SpeechToClassDataset +from artst.data.speech_dataset import SpeechPretrainDataset +from artst.data.text_dataset import TextPretrainDataset +from fairseq.data.shorten_dataset import maybe_shorten_dataset +from fairseq.tasks import LegacyFairseqTask, register_task +from fairseq.tasks.hubert_pretraining import LabelEncoder + +logger = logging.getLogger(__name__) + +TASK_NAME = ["s2t", "t2s", "s2s", "s2c", "pretrain"] + +@register_task("artst") +class ArTSTTask(LegacyFairseqTask): + @staticmethod + def add_args(parser): + parser.add_argument("data", help="manifest root path") + parser.add_argument( + "--config-yaml", + type=str, + default="config.yaml", + help="Configuration YAML filename (under manifest root)", + ) + parser.add_argument( + "--max-speech-sample-size", + default=None, + type=int, + metavar="N", + help="max speech sample size", + ) + parser.add_argument( + "--min-speech-sample-size", + default=None, + type=int, + metavar="N", + help="min speech sample size", + ) + parser.add_argument( + "--max-speech-positions", + default=4000, + type=int, + metavar="N", + help="max number of tokens in the source sequence", + ) + parser.add_argument( + "--max-text-positions", + default=450, + type=int, + metavar="N", + help="max number of tokens in the target sequence", + ) + parser.add_argument( + '--t5-task', + choices=TASK_NAME, + help='task for training' + ) + parser.add_argument( + "--bpe-tokenizer", + type=str, + default=None, + help="bpe tokenizer for s2t", + ) + # Speaker Identification (SID) + parser.add_argument( + "--finetune-from-modules", + default=None, + # choices=[ + # "encoder-decoder", "encoder", "decoder", + # "speech_encoder_prenet-encoder-decoder-text_decoder_prenet-text_decoder_postnet", # ASR, T5 SID + # "speech_encoder_prenet-encoder-decoder-text_decoder_prenet-speaker_decoder_postnet", # SID + # "speech_encoder_prenet-encoder-decoder-speech_decoder_prenet-speech_decoder_postnet", # VC, SE + # "text_encoder_prenet-encoder-decoder-speech_decoder_prenet-speech_decoder_postnet", # TTS + # ], + help="If set, using part modules of finetune model.", + ) + parser.add_argument( + "--finetune-out-of-modules", + default=None, + # choices=[ + # "speaker_decoder_postnet", # SID + # "speech_decoder_postnet", # SE with reduction factor 1 + # ], + help="If set, remove part modules of finetune model.", + ) + # BART + parser.add_argument( + "--shorten-method", + default="none", + choices=["none", "truncate", "random_crop"], + help="if not none, shorten sequences that exceed --tokens-per-sample", + ) + parser.add_argument( + "--shorten-data-split-list", + default="", + help="comma-separated list of dataset splits to apply shortening to, " + 'e.g., "train,valid" (default: all dataset splits)', + ) + + parser.add_argument( + "--tokens-per-sample", + default=512, + type=int, + help="max number of total tokens over all segments" + " per sample for dataset", + ) + parser.add_argument( + "--sample-break-mode", + default="eos", + type=str, + help="mode for breaking sentence", + 
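+            # passed to fairseq's TokenBlockDataset: "eos" keeps one sentence per block,
+            # "complete"/"complete_doc" pack whole sentences/documents up to
+            # --tokens-per-sample, "none" cuts fixed-size blocks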
) + parser.add_argument( + "--mask", + default=0.3, + type=float, + help="fraction of words/subwords that will be masked", + ) + parser.add_argument( + "--mask-random", + default=0.1, + type=float, + help="instead of using [MASK], use random token this often", + ) + parser.add_argument( + "--insert", + default=0.0, + type=float, + help="insert this percentage of additional random tokens", + ) + parser.add_argument( + "--permute", + default=0.0, + type=float, + help="take this proportion of subwords and permute them", + ) + parser.add_argument( + "--rotate", + default=0.0, + type=float, + help="rotate this proportion of inputs", + ) + parser.add_argument( + "--poisson-lambda", + default=3.5, + type=float, + help="randomly shuffle sentences for this proportion of inputs", + ) + parser.add_argument( + "--permute-sentences", + default=0.0, + type=float, + help="shuffle this proportion of sentences in all inputs", + ) + # parser.add_argument( + # "--mask-length", + # default="span-poisson", + # type=str, + # choices=["subword", "word", "span-poisson"], + # help="mask length to choose", + # ) + parser.add_argument( + "--replace-length", + default=1, + type=int, + help="when masking N tokens, replace with 0, 1, or N tokens (use -1 for N)", + ) + parser.add_argument( + "--iid-noise-target", + action="store_true", + help="whether to use t5 form target", + ) + # Hubert + parser.add_argument( + "--hubert-labels", + nargs="*", + type=str, + default=['km'], + help="extension of the label files to load, frame-level labels for pre-training, and sequence-level label for fine-tuning", + ) + parser.add_argument( + "--hubert-label-dir", + type=str, + default=None, + help="if set, looks for labels in this directory instead", + ) + parser.add_argument( + "--sample-rate", + default=100, + type=float, + help="target sample rate. 
audio files will be up/down sampled to this rate",
+        )
+        parser.add_argument(
+            "--label-rates",
+            default=-1,
+            type=float,
+            help="label frame rate; -1 for sequence-level labels",
+        )
+        parser.add_argument(
+            "--normalize",
+            action="store_true",
+            help="if set, normalizes input to have 0 mean and unit variance",
+        )
+        parser.add_argument(
+            "--enable-padding",
+            action="store_true",
+            help="pad shorter samples instead of cropping",
+        )
+        parser.add_argument(
+            "--pad-audio",
+            action="store_true",
+            help="pad audio to the longest one in the batch if true",
+        )
+        parser.add_argument(
+            "--random-crop",
+            action="store_true",
+            help="always crop from the beginning if false",
+        )
+        parser.add_argument(
+            "--single-target",
+            action="store_true",
+            help="if set, AddTargetDatasets outputs same keys "
+            "as AddTargetDataset",
+        )
+        parser.add_argument(
+            "--batch-ratio",
+            default=None,
+            type=str,
+            help="ratio of batch size for each dataset",
+        )
+        parser.add_argument(
+            "--sample-ratios",
+            default=None,
+            type=str,
+            help="sampling ratio for each dataset",
+        )
+        parser.add_argument(
+            "--ctc-weight",
+            type=float,
+            default=0.0,
+            help="ctc weight for inference",
+        )
+        parser.add_argument(
+            "--inference-speech",
+            type=bool,
+            default=False,
+            help="inference for TTS",
+        )
+
+    def __init__(self, args, dicts, config):
+        super().__init__(args)
+        self.dicts = dicts
+        self.config = config
+        self.t5_task = args.t5_task
+        # Used for filter size
+        if self.t5_task in ['s2t', 't2s', 's2s', 's2c']:
+            self.max_pos = [self.args.max_speech_positions * 256]
+        elif self.t5_task == 'pretrain':
+            self.max_pos = [self.args.max_speech_positions * 256, self.args.max_text_positions]
+
+        self.mask_idx = self.dicts["text"].add_symbol("<mask>")
+        # add blank token for ctc
+        # if args.ctc_weight > 0:
+        self.blank_symbol_idx = self.dicts["text"].add_symbol("<ctc_blank>")
+        self.blank_symbol = "<ctc_blank>"
+
+        # add mask token
+        if hasattr(args, "iid_noise_target") and args.iid_noise_target:
+            self.uni_mask_idxs = []
+            for i in range(600):
+                self.uni_mask_idxs.append(self.dicts["text"].add_symbol("<mask>" + str(i)))
+            self.uni_mask_idxs = torch.tensor(self.uni_mask_idxs)
+
+        self.seed = args.seed
+
+    @classmethod
+    def setup_task(cls, args, **kwargs):
+        # load dictionaries and config
+        dicts = OrderedDict()
+        if args.t5_task == 'pretrain' and not hasattr(args, "shuffle_instance"):
+            args.shuffle_instance = False
+
+        # Prepare config
+        config = None
+        logger.info('No config file for ' + args.t5_task)
+
+        if args.t5_task == "pretrain":
+            dicts["hubert"] = [Dictionary.load(f"{args.hubert_label_dir}/dict.{label}.txt") for label in args.hubert_labels]
+            dicts["text"] = Dictionary.load(op.join(args.data, "dict.txt"))
+        else:
+            if config is None:
+                dicts["text"] = Dictionary.load(op.join(args.data, "dict.txt"))
+            else:
+                dicts["text"] = Dictionary.load(op.join(args.data, config.vocab_filename))
+
+        return cls(args, dicts, config)
+
+    def build_criterion(self, args):
+        from fairseq import criterions
+        return criterions.build_criterion(args, self)
+
+    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
+        sample_ratios = []
+        if self.t5_task == "s2t":
+            ## For speech to text task
+            bpe_tokenizer = self.build_bpe(self.args)
+            manifest = f"{self.args.data}/{split}.tsv"
+            procs = [LabelEncoder(self.dicts["text"])]
+            paths = [f"{self.args.hubert_label_dir}/{split}.txt"]
+            # Hawau: view dataset...
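+            # s2t split layout: wav manifest at {data}/{split}.tsv plus line-aligned
+            # transcripts at {hubert_label_dir}/{split}.txt, encoded with the BPE
+            # tokenizer and mapped through dicts["text"].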
+ logger.info(f"Manifest: {manifest}") + # logger.info(f"Paths: {paths}") + self.datasets[split] = SpeechToTextDataset( + manifest, + sample_rate=self.args.sample_rate, + label_paths=paths, + label_processors=procs, + max_keep_sample_size=self.max_pos[0] if self.args.max_speech_sample_size is None else self.args.max_speech_sample_size, + min_keep_sample_size=self.args.min_speech_sample_size, + normalize=self.args.normalize, + store_labels=False, + tgt_dict=self.dicts["text"], + tokenizer=bpe_tokenizer, + ) + elif self.t5_task == "t2s": + ## For text to speech task + from fairseq.data import ConcatDataset + bpe_tokenizer = self.build_bpe(self.args) + procs = [LabelEncoder(self.dicts["text"])] + t2s_datasets = [ + TextToSpeechDataset( + manifest_path=f"{self.args.data}/{name}.tsv", + sample_rate=self.args.sample_rate, + label_paths=[f"{self.args.hubert_label_dir}/{name}.txt"], + label_processors=procs, + max_keep_sample_size=self.max_pos[0], + normalize=self.args.normalize, + store_labels=False, + src_dict=self.dicts["text"], + tokenizer=bpe_tokenizer, + reduction_factor=self.args.reduction_factor, + inference=self.args.inference_speech, + ) + for name in split.split(",") + ] + self.datasets[split] = ConcatDataset(t2s_datasets) if len(t2s_datasets) > 1 else t2s_datasets[0] + elif self.t5_task == "s2s": + manifest = f"{self.args.data}/{split}.tsv" + self.datasets[split] = SpeechToSpeechDataset( + manifest_path=manifest, + sample_rate=self.args.sample_rate, + max_keep_sample_size=self.max_pos[0] if self.args.max_speech_sample_size is None else self.args.max_speech_sample_size, + min_keep_sample_size=self.args.min_speech_sample_size, + normalize=self.args.normalize, + reduction_factor=self.args.reduction_factor, + ) + elif self.t5_task == "s2c": + is_train_split = ("train" in split) + is_valid_split = ("valid" in split) + if is_train_split: + max_length = 51200 + elif is_valid_split: + max_length = 76800 + else: + max_length = 2560000 + manifest = op.join(f"{self.args.data}", f"{split}.tsv") + procs = LabelEncoder(self.dicts["text"]) # map speaker to id + self.datasets[split] = SpeechToClassDataset( + manifest_path=manifest, + sample_rate=self.args.sample_rate, + label_processors=procs, + max_keep_sample_size=self.max_pos[0] if self.args.max_speech_sample_size is None else self.args.max_speech_sample_size, + min_keep_sample_size=self.args.min_speech_sample_size, + normalize=self.args.normalize, + tgt_dict=self.dicts["text"], + max_length=max_length + ) + elif self.t5_task == "pretrain": + is_train_split = ("train" in split) + pretrain_datasets = [] + speech_split, text_split = split.split('|') + + ## Speech pre-train + manifest = f"{self.args.data}/{speech_split}.tsv" + dicts = self.dicts["hubert"] + pad_list = [dict.pad() for dict in dicts] + eos_list = [dict.eos() for dict in dicts] + procs = [LabelEncoder(dict) for dict in dicts] + paths = [ + f"{self.args.hubert_label_dir}/{speech_split}.{l}" for l in self.args.hubert_labels + ] + # hubert v1: pad_audio=True, random_crop=False; + self.args.dec_weight = getattr(self.args, "dec_weight", 1.0) + pretrain_datasets.append( + SpeechPretrainDataset( + manifest, + sample_rate=self.args.sample_rate, + label_paths=paths, + label_rates=self.args.label_rates, + pad_list=pad_list, + eos_list=eos_list, + label_processors=procs, + max_keep_sample_size=None, + min_keep_sample_size=32000, + max_sample_size=self.args.max_speech_sample_size, + pad_audio=self.args.pad_audio, + normalize=self.args.normalize, + store_labels=False, + 
random_crop=self.args.random_crop,
+                    single_target=self.args.single_target,
+                    reduction_factor=self.args.reduction_factor,
+                )
+            )
+            sample_ratios.append(sum([pretrain_datasets[0].size(i) for i in range(len(pretrain_datasets[0]))]))
+
+            ## Text pre-train
+            paths = utils.split_paths(self.args.data)
+            assert len(paths) > 0
+            data_path = paths[(epoch - 1) % len(paths)]
+            print(f"Loading {text_split} from data_path={data_path}")
+            split_path = op.join(data_path, text_split)
+            print(f"split_path={split_path}")
+            bart_dataset = data_utils.load_indexed_dataset(
+                split_path,
+                self.dicts["text"],
+                self.args.dataset_impl,
+                combine=combine,
+            )
+            if bart_dataset is None:
+                raise FileNotFoundError(
+                    "Dataset not found: {} ({})".format(text_split, split_path)
+                )
+            bart_dataset = StripTokenDataset(bart_dataset, self.dicts["text"].eos())
+            bart_dataset = maybe_shorten_dataset(
+                bart_dataset,
+                text_split,
+                self.args.shorten_data_split_list,
+                self.args.shorten_method,
+                self.args.tokens_per_sample,
+                self.args.seed,
+            )
+            # create continuous blocks of tokens
+            bart_dataset = TokenBlockDataset(
+                bart_dataset,
+                bart_dataset.sizes,
+                self.args.tokens_per_sample - 2,  # one less for <s> and one for </s>
+                pad=self.dicts["text"].pad(),
+                eos=self.dicts["text"].eos(),
+                break_mode=self.args.sample_break_mode,
+                document_sep_len=0,
+            )
+            # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
+            bart_dataset = PrependTokenDataset(bart_dataset, self.dicts["text"].bos())
+            bart_dataset = AppendTokenDataset(bart_dataset, self.dicts["text"].eos())
+            mask_whole_words = (
+                get_whole_word_mask(self.args, self.dicts["text"])
+                if self.args.mask_length != "subword"
+                else None
+            )
+            self.args.bert_weight = getattr(self.args, "bert_weight", 0.0)
+            pretrain_datasets.append(
+                TextPretrainDataset(
+                    bart_dataset,
+                    bart_dataset.sizes,
+                    self.dicts["text"],
+                    self.mask_idx,
+                    mask_whole_words,
+                    shuffle=self.args.shuffle_instance,
+                    seed=self.seed,
+                    args=self.args,
+                    iid_noise_target=self.args.iid_noise_target,
+                    uni_mask_idxs=self.uni_mask_idxs if self.args.iid_noise_target else None,
+                )
+            )
+            sample_ratios.append(sum(pretrain_datasets[1].sizes))
+            logger.info(
+                "Task: {0}, Loaded {1} samples of denoising_dataset".format(
+                    'bart',
+                    len(pretrain_datasets[1]),
+                )
+            )
+
+            logger.info('token ratio is ' + str(sample_ratios))
+            if self.args.batch_ratio is not None:
+                batch_ratio = eval(self.args.batch_ratio)
+                assert len(batch_ratio) == len(sample_ratios)
+                sample_ratios = [sample_ratios[i] / batch_ratio[i] for i in range(len(sample_ratios))]
+            else:
+                batch_ratio = None
+            max_size = max(sample_ratios)
+            sample_ratios = [max_size / r for r in sample_ratios]
+            if hasattr(self.args, "sample_ratios") and self.args.sample_ratios is not None:
+                sample_ratios = eval(self.args.sample_ratios)
+            if is_train_split:
+                self.datasets[split] = MultitaskDataset(
+                    pretrain_datasets, sample_ratios, batch_ratio
+                )
+            else:
+                self.datasets[split] = MultitaskDataset(
+                    pretrain_datasets, batch_ratio=batch_ratio
+                )
+
+    def train_step(
+        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
+    ):
+        model.train()
+        model.set_num_updates(update_num)
+
+        # Junyi: do not use sample_size; normalize the loss locally instead
+        agg_loss, agg_sample_size, agg_logging_output = 0.0, 1.0, {}
+        agg_logging_output['sample_size'] = 1
+
+        def forward_backward(model, samples, weight=1.0):
+            nonlocal agg_loss, agg_logging_output
+            if samples is None or len(samples) == 0:
+                return
+            loss, sample_size, logging_output = criterion(model, samples)
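+            # criterion returns (loss, sample_size, logging_output); the loss is
+            # scaled by the task weight and divided by sample_size below, so every
+            # task in the multitask mix contributes a per-sample-normalized loss
+            # (which is why agg_sample_size is fixed at 1.0 above).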
+            if ignore_grad:
+                loss *= 0
+            else:
+                loss *= weight
+            loss = loss / sample_size
+            optimizer.backward(loss)
+            agg_loss += loss.detach().item()
+            # # TODO make summing of the sample sizes configurable
+            for k in logging_output:
+                if k == 'ntokens' or k == 'nsentences':
+                    if k not in agg_logging_output:
+                        agg_logging_output[k] = 0
+                    agg_logging_output[k] += logging_output[k]
+                    # continue
+                # agg_logging_output[k] += logging_output[k]
+            # agg_logging_output[task_name] += logging_output[k]
+            agg_logging_output[samples['task_name']] = logging_output
+
+        forward_backward(model, sample)
+
+        agg_logging_output["loss"] = agg_loss
+
+        return agg_loss, agg_sample_size, agg_logging_output
+
+    def valid_step(self, sample, model, criterion):
+        model.eval()
+        with torch.no_grad():
+            from collections import defaultdict
+
+            agg_loss, agg_sample_size, agg_logging_output = 0.0, 1.0, defaultdict(float)
+            agg_logging_output['sample_size'] = 1
+            loss, sample_size, logging_output = criterion(model, sample)
+            loss = loss / sample_size
+            # agg_loss += loss.data.item() if isinstance(loss, torch.Tensor) else loss
+            agg_loss += loss.item() if isinstance(loss, torch.Tensor) else loss
+            agg_logging_output[sample['task_name']] = logging_output
+            agg_logging_output["loss"] = agg_loss
+        return agg_loss, agg_sample_size, agg_logging_output
+
+    @property
+    def target_dictionary(self):
+        return self.dicts["text"]
+
+    @property
+    def source_dictionary(self):
+        return None
+
+    def build_model(self, args):
+        try:
+            args.input_feat_per_channel = self.config.input_feat_per_channel
+            args.input_channels = self.config.input_channels
+        except Exception as e:
+            args.input_feat_per_channel = 80
+            args.input_channels = 1
+            logger.info("Cannot set input_feat_per_channel / input_channels from config:")
+            logger.warning(e)
+            logger.info(f"Set to: {args.input_feat_per_channel} and {args.input_channels}")
+
+        args.speech_odim = args.input_feat_per_channel * args.input_channels
+
+        args.label_rates = self.args.label_rates
+        args.sample_rate = self.args.sample_rate
+        self.args.reduction_factor = args.reduction_factor
+        return super(ArTSTTask, self).build_model(args)
+
+    def build_generator(
+        self,
+        models,
+        args,
+        seq_gen_cls=None,
+        extra_gen_cls_kwargs=None,
+    ):
+        from artst.sequence_generator import SequenceGenerator
+        extra_gen_cls_kwargs = {
+            "ctc_weight": self.args.ctc_weight,
+            **(extra_gen_cls_kwargs or {}),
+        }
+        return super().build_generator(
+            models, args, seq_gen_cls=SequenceGenerator, extra_gen_cls_kwargs=extra_gen_cls_kwargs
+        )
+
+    def build_tokenizer(self, args):
+        if self.config is None:
+            logger.info("pre-tokenizer: None")
+            return encoders.build_tokenizer(Namespace(**{"tokenizer": None}))
+        else:
+            logger.info(f"pre-tokenizer: {self.config.pre_tokenizer}")
+            return encoders.build_tokenizer(Namespace(**self.config.pre_tokenizer))
+
+    def build_bpe(self, args):
+        if self.config is not None:
+            logger.info(f"tokenizer: {self.config.bpe_tokenizer}")
+            return encoders.build_bpe(Namespace(**self.config.bpe_tokenizer))
+        else:
+            logger.info(f"tokenizer: {self.args.bpe_tokenizer}")
+            return encoders.build_bpe(Namespace(**{"bpe": "sentencepiece", "sentencepiece_model": self.args.bpe_tokenizer}))
+
+    def generate_class(self, models, net_input, prefix_tokens, **kwargs):
+        with torch.no_grad():
+            encoder_input = {
+                k: v for k, v in net_input.items() if k != "prev_output_tokens" and k != "task_name"
+            }
+            encoder_input.update(kwargs)
+            encoder_input.update({"prev_output_tokens": prefix_tokens})
+            return
models[0].generate_class(**encoder_input) + + def generate_speech(self, models, net_input, **kwargs): + with torch.no_grad(): + encoder_input = { + k: v for k, v in net_input.items() if k != "prev_output_tokens" and k != "task_name" + } + encoder_input.update(kwargs) + return models[0].generate_speech(**encoder_input) + + def inference_t2s( + self, models, sample + ): + with torch.no_grad(): + xs = sample['net_input']['src_tokens'] + spkemb = sample['net_input']['spkembs'] + return models[0].inference(xs, spkemb) + + def inference_s2s( + self, models, sample, force_equal_length=False + ): + with torch.no_grad(): + x = sample['net_input']['src_tokens'] + xlen = sample['net_input']['src_lengths'] + spkemb = sample['net_input']['spkembs'] + prev_output_tokens = sample['net_input']['prev_output_tokens'] + padding_mask = sample['net_input']['padding_mask'] + tgt_lengths = sample['net_input']['tgt_lengths'] + return models[0].inference_s2s(x, xlen, spkemb, prev_output_tokens, tgt_lengths, force_equal_length=force_equal_length, padding_mask=padding_mask) + + def inference_s2c( + self, models, sample + ): + with torch.no_grad(): + x = sample['net_input']['src_tokens'] + xlen = sample['net_input']['src_lengths'] + prev_output_tokens = sample['net_input']['prev_output_tokens'] + padding_mask = sample['net_input']['padding_mask'] + assert prev_output_tokens.size(1) == 1, prev_output_tokens.size() + return models[0].inference_s2c(x, xlen, prev_output_tokens, padding_mask=padding_mask) + + def filter_indices_by_size( + self, indices, dataset, max_positions=None, ignore_invalid_inputs=False + ): + """ + Filter examples that are too large + + Args: + indices (np.array): original array of sample indices + dataset (~fairseq.data.FairseqDataset): dataset to batch + max_positions (optional): max sentence length supported by the + model (default: None). + ignore_invalid_inputs (bool, optional): don't raise Exception for + sentences that are too long (default: False). 
+ Returns: + np.array: array of filtered sample indices + """ + + indices, ignored = dataset.filter_indices_by_size( + indices, + self.max_pos + ) + return indices diff --git a/ckpts/clartts_tts.pt b/ckpts/clartts_tts.pt new file mode 100644 index 0000000000000000000000000000000000000000..638d68e7e5709c26a5ac995e5de2e17a7dea1a87 --- /dev/null +++ b/ckpts/clartts_tts.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4dc9ed84bccb7227fc467f7bd415ee26af0af3fd82bf7715a3166a3432649812 +size 1847114585 diff --git a/embs/clartts.npy b/embs/clartts.npy new file mode 100644 index 0000000000000000000000000000000000000000..497f5805c1c878bbe4661ca321b4b796b2008172 --- /dev/null +++ b/embs/clartts.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7b61cf884ce26273202d369db6bbd7d504376c4dca287979ce702fd0b108c30f +size 2176 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..59ad5d04a1ec7db04a54a3a84eddfc16f2352773 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,33 @@ +cython==0.29.35 +fairseq==0.12.2 +datasets==2.12.0 +editdistance==0.6.2 +espnet==202304 +espnet-tts-frontend==0.0.3 +librosa==0.9.2 +omegaconf==2.0.6 +pandas==2.0.1 +PyArabic==0.6.15 +scipy +soundfile +tqdm==4.65.0 +tweepy==4.14.0 +tensorboard +kaldiio==2.18.0 +numpy==1.23.5 +cmake==3.26.4 +pillow==10.0.0 +nvidia-cublas-cu11==11.10.3.66 +nvidia-cuda-cupti-cu11==11.7.101 +nvidia-cuda-nvrtc-cu11==11.7.99 +nvidia-cuda-runtime-cu11==11.7.99 +nvidia-cudnn-cu11==8.5.0.96 +nvidia-cufft-cu11==10.9.0.58 +nvidia-curand-cu11==10.2.10.91 +nvidia-cusolver-cu11==11.4.0.1 +nvidia-cusparse-cu11==11.7.4.91 +nvidia-nccl-cu11==2.14.3 +nvidia-nvtx-cu11==11.7.91 +tensorboardx==2.6 +transformers +speechbrain \ No newline at end of file
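For reference, the balancing logic in the `pretrain` branch of `load_dataset` above reduces to a few lines of arithmetic: each dataset's token count is optionally divided by its entry in `--batch-ratio`, and every dataset is then up-sampled relative to the largest one. A minimal, self-contained sketch of that computation (the helper name and the example counts below are illustrative, not part of the repository):

```python
def compute_sample_ratios(token_counts, batch_ratio=None):
    # Mirrors the ratio arithmetic in ArTSTTask.load_dataset (pretrain branch).
    ratios = [float(c) for c in token_counts]
    if batch_ratio is not None:
        assert len(batch_ratio) == len(ratios)
        ratios = [r / b for r, b in zip(ratios, batch_ratio)]
    largest = max(ratios)
    # Each dataset is repeated so that, in expectation, all of them contribute
    # roughly the same number of tokens per epoch.
    return [largest / r for r in ratios]

# 1,000,000 speech frames vs. 250,000 text tokens -> the text split is sampled 4x.
print(compute_sample_ratios([1_000_000, 250_000]))                      # [1.0, 4.0]
print(compute_sample_ratios([1_000_000, 250_000], batch_ratio=[2, 1]))  # [1.0, 2.0]
```

If `--sample-ratios` is given, it simply overrides the computed values (the string is parsed with `eval`), and for validation splits `MultitaskDataset` is built without sample ratios, so no split is over-sampled.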