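# NOTE: the next three lines appear to be residue from the hosting web page
# (uploader, commit message, commit id); kept as comments so the module parses.
# herwoww's picture
# first upload
# 1547a56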
# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------
import logging
from ast import literal_eval
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.models import (
FairseqEncoderDecoderModel,
FairseqIncrementalDecoder,
register_model,
register_model_architecture,
)
from .modules.text_encoder_prenet import TextEncoderPrenet
from .modules.text_decoder_prenet import TextDecoderPrenet
from .modules.text_decoder_postnet import TextDecoderPostnet
from .modules.speech_encoder_prenet import SpeechEncoderPrenet
from .modules.speech_encoder_postnet import SpeechEncoderPostnet
from .modules.speech_decoder_prenet import SpeechDecoderPrenet
from .modules.speech_decoder_postnet import SpeechDecoderPostnet
from .modules.speaker_decoder_postnet import SpeakerDecoderPostnet
from .modules.encoder import TransformerEncoder
from .modules.decoder import TransformerDecoder
from fairseq.modules.transformer_sentence_encoder import init_bert_params
from fairseq.models.transformer import Embedding
from fairseq.modules import (
GumbelVectorQuantizer,
)
from torch import Tensor
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Position-limit constants. They are not referenced in the visible part of
# this file. NOTE(review): presumably used as defaults for the maximum text /
# speech sequence lengths elsewhere — confirm against the rest of the file.
DEFAULT_MAX_TEXT_POSITIONS = 450
DEFAULT_MAX_SPEECH_POSITIONS = 4000
@register_model("artst_transformer")
class ArTSTTransformerModel(FairseqEncoderDecoderModel):
"""Adapted Transformer model (https://arxiv.org/abs/1706.03762) for
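text and speech tasks. The Transformer encoder/decoder remains the same and
is shared across modalities; it is wrapped with modality-specific pre-nets
and post-nets (text and speech encoder pre-nets, text and speech decoder
pre-nets, text/speech/speaker decoder post-nets and an optional speech
encoder post-net). For speech input, a trainable input subsampler is
prepended to the Transformer encoder to project inputs into the encoder
dimension as well as downsample the input sequence for computational
efficiency."""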
def __init__(
    self,
    args,
    encoder, decoder,
    text_encoder_prenet, speech_encoder_prenet,
    text_decoder_prenet, speech_decoder_prenet,
    text_decoder_postnet, speech_decoder_postnet,
    speaker_decoder_postnet, speech_encoder_postnet,
):
    """Assemble the model from already-built components.

    Args:
        args: Configuration namespace. Read here: ``reduction_factor``,
            ``spk_embed_dim``, ``spk_embed_integration_type``,
            ``decoder_embed_dim``, ``use_codebook``, ``codebook_prob``
            (optional, defaults to 0.5), the ``latent_*`` / ``quantizer_*``
            options when the codebook is enabled, ``bert_init`` and
            ``modules_filter``.
        encoder, decoder: Shared Transformer encoder and decoder.
        text_encoder_prenet, speech_encoder_prenet: Encoder-side pre-nets.
        text_decoder_prenet, speech_decoder_prenet: Decoder-side pre-nets.
        text_decoder_postnet, speech_decoder_postnet,
        speaker_decoder_postnet: Decoder-side post-nets.
        speech_encoder_postnet: Encoder-side post-net; stored under the
            attribute name ``hubert_layer``.
    """
    super().__init__(encoder, decoder)

    # Components
    self.encoder = encoder
    self.decoder = decoder
    self.text_encoder_prenet = text_encoder_prenet
    self.speech_encoder_prenet = speech_encoder_prenet
    self.text_decoder_prenet = text_decoder_prenet
    self.speech_decoder_prenet = speech_decoder_prenet
    self.text_decoder_postnet = text_decoder_postnet
    self.speech_decoder_postnet = speech_decoder_postnet
    self.speaker_decoder_postnet = speaker_decoder_postnet
    self.hubert_layer = speech_encoder_postnet

    self.reduction_factor = args.reduction_factor
    self.spk_embed_dim = args.spk_embed_dim

    # Speaker-embedding projection: only needed when the embedding is not
    # integrated by the decoder prenet ('pre').
    self.spk_embed_integration_type = args.spk_embed_integration_type
    if self.spk_embed_dim is not None and self.spk_embed_integration_type != 'pre':
        out_dim = args.decoder_embed_dim
        if self.spk_embed_integration_type == "add":
            in_dim = self.spk_embed_dim
        else:
            in_dim = args.decoder_embed_dim + self.spk_embed_dim
        self.projection = torch.nn.Linear(in_dim, out_dim)
    # Hawau: here we can add language embedding integration

    # Optional vector quantizer applied to the encoder output
    self.use_codebook = args.use_codebook
    self.codebook_prob = getattr(args, "codebook_prob", 0.5)  # args.codebook_prob
    if self.use_codebook:
        latent_dim = args.latent_dim
        self.quantizer = GumbelVectorQuantizer(
            dim=args.encoder_embed_dim,
            num_vars=args.latent_vars,
            temp=args.latent_temp,
            groups=args.latent_groups,
            combine_groups=False,
            vq_dim=latent_dim if latent_dim > 0 else args.encoder_embed_dim,
            time_first=True,
            weight_proj_depth=args.quantizer_depth,
            weight_proj_factor=args.quantizer_factor,
        )

    self.num_updates = 0

    # Follow BERT's random weight initialization (for BART)
    if args.bert_init:
        self.apply(init_bert_params)

    # args must be stored before pruning: prune_modules reads self.args.
    self.args = args
    self.prune_modules(args.modules_filter)
@staticmethod
def add_args(parser):
    """Add model-specific arguments to the parser.

    Args:
        parser: Object exposing ``add_argument`` (an ``argparse`` parser or
            argument group) on which all model options are registered.

    Options without an explicit ``default`` are registered without one here.
    """
    # Transformer
    parser.add_argument(
        "--activation-fn",
        type=str,
        choices=utils.get_available_activation_fns(),
        help="activation function to use",
    )
    parser.add_argument(
        "--dropout", type=float, metavar="D", help="dropout probability"
    )
    parser.add_argument(
        "--attention-dropout",
        type=float,
        metavar="D",
        help="dropout probability for attention weights",
    )
    parser.add_argument(
        "--activation-dropout",
        "--relu-dropout",
        type=float,
        metavar="D",
        help="dropout probability after activation in FFN.",
    )
    parser.add_argument(
        "--encoder-embed-dim",
        type=int,
        metavar="N",
        help="encoder embedding dimension",
    )
    parser.add_argument(
        "--encoder-ffn-embed-dim",
        type=int,
        metavar="N",
        help="encoder embedding dimension for FFN",
    )
    parser.add_argument(
        "--encoder-layers", type=int, metavar="N", help="num encoder layers"
    )
    parser.add_argument(
        "--encoder-attention-heads",
        type=int,
        metavar="N",
        help="num encoder attention heads",
    )
    parser.add_argument(
        "--encoder-normalize-before",
        action="store_true",
        help="apply layernorm before each encoder block",
    )
    parser.add_argument(
        "--decoder-normalize-before",
        action="store_true",
        help="apply layernorm before each decoder block",
    )
    parser.add_argument(
        "--decoder-embed-dim",
        type=int,
        metavar="N",
        help="decoder embedding dimension",
    )
    parser.add_argument(
        "--decoder-ffn-embed-dim",
        type=int,
        metavar="N",
        help="decoder embedding dimension for FFN",
    )
    parser.add_argument(
        "--decoder-layers", type=int, metavar="N", help="num decoder layers"
    )
    parser.add_argument(
        "--decoder-attention-heads",
        type=int,
        metavar="N",
        help="num decoder attention heads",
    )
    parser.add_argument(
        "--reduction-factor",
        type=int,
        help="reduction factor for decoder",
    )
    parser.add_argument(
        "--spk-embed-dim",
        type=int,
        help="speaker embedding dimension",
    )
    parser.add_argument(
        "--layernorm-embedding",
        action="store_true",
        help="add layernorm to embedding",
    )
    parser.add_argument(
        "--load-pretrained-encoder-from",
        type=str,
        metavar="STR",
        help="model to take encoder weights from (for initialization)",
    )
    parser.add_argument(
        '--freeze-encoder-updates',
        type=int,
        help='number of steps to freeze encoder before finetune'
    )
    parser.add_argument(
        '--freeze-decoder-updates',
        type=int,
        help='number of steps to freeze decoder before finetune'
    )
    parser.add_argument(
        '--no-freeze-encoder-layer',
        type=str,
        help='which encoder layer not freeze during finetune'
    )
    parser.add_argument(
        "--share-input-output-embed",
        action="store_true",
        help="share decoder input and output embeddings",
    )
    parser.add_argument(
        "--share-ctc-embed",
        action="store_true",
        help="share ctc embed and decoder embed",
    )
    parser.add_argument(
        "--encoder-sliding-window-attn",
        default=None,
        type=int,
        help="If not None but an even number, set sliding window attention to encoder's attn_mask, e.g., 4, 10, and 20",
    )

    # Convolutional subsampler
    parser.add_argument(
        "--encoder-speech-prenet",
        default="conv",
        type=str,
        choices=["conv", "linear"],
        help="The type of encoder speech prenet, e.g., conv or linear."
    )
    parser.add_argument(
        "--conv-kernel-sizes",
        default="5,5",
        type=str,
        help="The layer of convolution of encoder speech prenet."
    )
    parser.add_argument(
        "--conv-channels",
        default=1024,
        type=int,
        help="The channels of encoder speech prenet."
    )
    parser.add_argument(
        "--subsample-stride",
        default="2,2",
        type=str,
        help="The subsample stride for conv1dsubsample."
    )
    # "concat" is accepted because the model builds a concat projection for
    # any non-'pre', non-'add' type and _integrate_with_spk_embed handles it.
    parser.add_argument(
        "--spk-embed-integration-type",
        type=str,
        choices=["pre", "add", "concat"],
        help="speaker embedding integration type"
    )
    parser.add_argument(
        "--dprenet-dropout-rate",
        default=0.5,
        type=float,
        help="The dropout rate of decoder speech prenet."
    )

    ## SE
    parser.add_argument(
        "--se-predict",
        default=None,
        choices=["masking", "target", "delta"],
        help="If set, source speech inputs decoder to predict the masking/target/delta of corresponding inputs. "
            + "masking is [0, 1], target is predicted output, delta is difference between inputs and outputs",
    )
    parser.add_argument(
        "--se-decoder-input",
        type=str,
        default="previous_target",
        choices=["previous_target", "source"],
    )

    ## SID
    parser.add_argument(
        "--modules-filter",
        default=None,
        type=str,
        help="Remove unused modules for, e.g., SID.",
    )
    parser.add_argument(
        "--sid-pad-prenet",
        action="store_true",
        help="If set, the size of text dictionary is as small as for <pad> token.",
    )
    parser.add_argument(
        "--encoder-attn-branch",
        type=str,
        default="identity,full",
        help="encoder attention branch sliding window, e.g., 'identity,0,2,4,full'",
    )
    parser.add_argument(
        "--encoder-block-branch",
        type=str,
        help="average the output of encoder, e.g., '4,5,6'",
    )
    parser.add_argument(
        "--sid-encoder-cls",
        default=None,
        choices=["encoder"],
        help="If set, add cls vector to the encoder input, e.g., constant vector.",
    )
    parser.add_argument(
        "--sid-shuffle-encoder-input",
        action="store_true",
        help="If set, shuffle encoder input in time.",
    )
    parser.add_argument(
        "--sid-decoder-speaker",
        action="store_true",
        help="If set, apply speaker decoder as transformer decoder.",
    )
    parser.add_argument(
        "--sid-decoder-attn-dim",
        default=128,
        type=int,
        help="Attention dimension in attentive statistics pooling of speaker decoder.",
    )
    parser.add_argument(
        "--sid-t5-postnet",
        action="store_true",
        help="If set, apply TextDecoderPostnet as speaker classification.",
    )
    parser.add_argument(
        "--sid-embed-dim",
        default=128,
        type=int,
        help="Embedding dimension in speaker postnet for speaker identification if embed postnet.",
    )
    parser.add_argument(
        "--sid-pooling-layer",
        default="decoder",
        type=str,
        choices=["decoder-las", "decoder", "encoder", "encoder-cls", "encoder-speaker"],
        help="The output of decoder or encoder uses as SID pooling layer over temporal dimension.",
    )
    parser.add_argument(
        "--sid-no-pooling-bn",
        action="store_true",
        help="If set, not attention batchnorm.",
    )
    parser.add_argument(
        "--sid-no-embed-postnet",
        action="store_true",
        help="If set, no layer between decoder output and classification layer.",
    )
    parser.add_argument(
        "--sid-normalize-postnet",
        action="store_true",
        help="If set, normalize input and weight in postnet/classifier.",
    )
    parser.add_argument(
        "--sid-softmax-type",
        default="softmax",
        choices=["softmax", "amsoftmax", "aamsoftmax"],
        help="If using amsoftmax or aamsoftmax, the target should be given.",
    )
    parser.add_argument(
        "--softmax-scale",
        default=1.0,
        type=float,
        help="Scale for AMSoftmax or AAMSoftmax.",
    )
    parser.add_argument(
        "--softmax-margin",
        default=0.0,
        type=float,
        help="Margin for AMSoftmax or AAMSoftmax.",
    )
    parser.add_argument(
        "--softmax-easy-margin",
        action="store_true",
        help="Enable easy margin for AAMSoftmax.",
    )
    parser.add_argument(
        "--encoder-layerdrop",
        type=float,
        metavar="D",
        help="LayerDrop probability for encoder",
    )
    parser.add_argument(
        "--decoder-layerdrop",
        type=float,
        metavar="D",
        help="LayerDrop probability for decoder",
    )

    ## Hubert
    parser.add_argument(
        '--feature-grad-mult',
        type=float,
        help='multiply feature extractor var grads by this'
    )
    parser.add_argument(
        '--logit-temp',
        type=float,
        help='temperature to divide logits by'
    )
    parser.add_argument(
        '--final-dim',
        type=int,
        help="project final representations and targets to this many "
        "dimensions. set to encoder_embed_dim if <= 0"
    )

    # mask
    parser.add_argument(
        '--hubert-mask-length',
        type=int,
        help='mask length'
    )
    parser.add_argument(
        '--mask-prob',
        type=float,
        help='probability of replacing a token with mask'
    )
    parser.add_argument(
        "--mask-selection",
        choices=["static", "uniform", "normal", "poisson"],
        help="how to choose mask length",
    )
    parser.add_argument(
        '--mask-other',
        type=float,
        help="secondary mask argument "
        "(used for more complex distributions), "
        "see help in compute_mask_indices"
    )
    parser.add_argument(
        '--mask-min-space',
        type=int,
        help='min space between spans (if no overlap is enabled)'
    )

    # channel masking
    parser.add_argument(
        '--mask-channel-length',
        type=int,
        help='length of the mask for features (channels)'
    )
    parser.add_argument(
        '--mask-channel-prob',
        type=float,
        help="probability of replacing a feature with 0"
    )
    parser.add_argument(
        "--mask-channel-selection",
        choices=["static", "uniform", "normal", "poisson"],
        help="how to choose mask length for channel masking",
    )
    parser.add_argument(
        '--mask-channel-other',
        type=float,
        help="secondary mask argument "
        "(used for more complex distributions), "
        "see help in compute_mask_indices"
    )
    parser.add_argument(
        '--mask-channel-min-space',
        type=int,
        help='min space between spans (if no overlap is enabled)'
    )

    # abs positional embeddings
    parser.add_argument(
        '--conv-pos',
        type=int,
        help='number of filters for convolutional positional embeddings'
    )
    parser.add_argument(
        '--conv-pos-groups',
        type=int,
        help='number of groups for convolutional positional embedding'
    )

    # codebook related
    parser.add_argument(
        "--use-codebook",
        action="store_true",
        help="whether to use codebook",
    )
    parser.add_argument(
        "--codebook-prob",
        type=float,
        help="probability to use codebook",
    )
    parser.add_argument(
        "--latent-vars",
        type=int,
        help="number of latent variables V in each group of the codebook",
    )
    parser.add_argument(
        "--latent-groups",
        type=int,
        help="number of groups G of latent variables in the codebook",
    )
    # The model falls back to encoder_embed_dim when this is <= 0.
    parser.add_argument(
        "--latent-dim",
        type=int,
        help="if > 0, uses this dimensionality for latent variables. "
        "otherwise uses encoder_embed_dim",
    )
    parser.add_argument(
        "--latent-temp",
        type=literal_eval,
        help="temperature for latent variable sampling. "
        "can be tuple of 3 values (start, end, decay)",
    )
    parser.add_argument(
        "--quantizer-depth",
        type=int,
        help="number of quantizer layers",
    )
    # Passed to the quantizer as weight_proj_factor.
    parser.add_argument(
        "--quantizer-factor",
        type=int,
        help="dimensionality increase for inner quantizer layers (if depth > 1)",
    )
    parser.add_argument(
        "--get-code-distribution",
        action='store_true',
        help="whether to get the code distribution (for test)",
    )

    # relative pos enc
    parser.add_argument(
        "--relative-position-embedding",
        action='store_true',
        help="whether to use relative position embedding",
    )
    parser.add_argument(
        "--num-buckets",
        type=int,
        default=320,
        help="num of buckets for relative position embedding",
    )
    parser.add_argument(
        "--max-distance",
        type=int,
        default=1280,
        help="max distance for relative position embedding",
    )
    parser.add_argument(
        "--encoder-max-relative-position",
        type=int,
        help="max distance for relative position embedding in encoder",
    )
    parser.add_argument(
        "--decoder-max-relative-position",
        type=int,
        help="max distance for relative position embedding in decoder",
    )

    # hubert feature extractor
    parser.add_argument(
        "--conv-feature-layers",
        type=str,
        help="string describing convolutional feature extraction "
        "layers in form of a python list that contains "
        "[(dim, kernel_size, stride), ...]",
    )
    parser.add_argument(
        "--conv-bias",
        action='store_true',
        help="include bias in conv encoder",
    )
    parser.add_argument(
        "--extractor-mode",
        choices=["default", "layer_norm"],
        help="mode for feature extractor. default has a single group "
        "norm with d groups in the first conv block, whereas layer_norm "
        "has layer norms in every block (meant to use with normalize=True)"
    )

    # others
    parser.add_argument(
        "--bert-init",
        action='store_true',
        help="initialize as bert",
    )
    parser.add_argument(
        "--unb-enc-layer",
        type=int,
        default=-1,
        help="which layer's output is used as the input of decoder",
    )
# Encoder, Decoder
@classmethod
def build_encoder(cls, args, dictionary=None, embed_tokens=None):
    """Create the Transformer encoder from ``args`` and the optional text
    dictionary / embedding table."""
    encoder = TransformerEncoder(args, dictionary, embed_tokens)
    return encoder
@classmethod
def build_decoder(cls, args):
    """Create the Transformer decoder from ``args``."""
    decoder = TransformerDecoder(args)
    return decoder
# Encoder Prenet
@classmethod
def build_text_encoder_prenet(cls, embed_tokens, args):
    """Create the text encoder pre-net around ``embed_tokens``."""
    prenet = TextEncoderPrenet(embed_tokens, args)
    return prenet
@classmethod
def build_speech_encoder_prenet(cls, args):
    """Create the speech encoder pre-net from ``args``."""
    prenet = SpeechEncoderPrenet(args)
    return prenet
# Decoder Prenet
@classmethod
def build_text_decoder_prenet(cls, embed_tokens, args):
    """Create the text decoder pre-net around ``embed_tokens``."""
    prenet = TextDecoderPrenet(embed_tokens, args)
    return prenet
@classmethod
def build_speech_decoder_prenet(cls, odim, args):
    """Create the speech decoder pre-net for output dimension ``odim``."""
    prenet = SpeechDecoderPrenet(odim, args)
    return prenet
# Decoder Postnet
@classmethod
def build_text_decoder_postnet(cls, embed_tokens, dictionary, args):
    """Create the text decoder post-net from the embedding table and
    dictionary."""
    postnet = TextDecoderPostnet(embed_tokens, dictionary, args)
    return postnet
@classmethod
def build_speaker_decoder_postnet(cls, embed_dim, class_num, args):
    """Create the speaker decoder post-net for ``class_num`` classes."""
    postnet = SpeakerDecoderPostnet(embed_dim, class_num, args)
    return postnet
@classmethod
def build_speech_decoder_postnet(cls, odim, args):
    """Create the speech decoder post-net for output dimension ``odim``."""
    postnet = SpeechDecoderPostnet(odim, args)
    return postnet
@classmethod
def build_speech_encoder_postnet(cls, dictionaries, args):
    """Create the speech encoder post-net from the given dictionaries."""
    postnet = SpeechEncoderPostnet(dictionaries, args)
    return postnet
@classmethod
def build_model(cls, args, task):
    """Build a new model instance.

    Args:
        args: Model configuration; missing entries are filled in by
            ``base_architecture``.
        task: Task object providing ``dicts`` (reads the "text" and,
            optionally, "hubert" entries) and ``t5_task``.

    Returns:
        A new instance of ``cls`` wired with freshly built components.
    """
    # make sure all arguments are present in older models
    base_architecture(args)

    def build_embedding(dictionary, embed_dim, max_num_embeddings=None):
        vocab_size = len(dictionary)
        if isinstance(max_num_embeddings, int):
            vocab_size = min(vocab_size, max_num_embeddings)
        return Embedding(vocab_size, embed_dim, dictionary.pad())

    # With --sid-pad-prenet the decoder embedding table is capped at three
    # rows (<pad> at index 2).
    max_num_embeddings = 3 if getattr(args, "sid_pad_prenet", False) else None

    text_dict = task.dicts["text"]
    text_decoder_embed_tokens = build_embedding(
        text_dict, args.decoder_embed_dim, max_num_embeddings
    )
    if args.share_input_output_embed:
        text_encoder_embed_tokens = text_decoder_embed_tokens
    else:
        text_encoder_embed_tokens = build_embedding(text_dict, args.encoder_embed_dim)

    speech_odim = args.speech_odim

    # Shared encoder / decoder
    if "text" in task.dicts:
        encoder = cls.build_encoder(args, task.dicts["text"], text_encoder_embed_tokens)
    else:
        encoder = cls.build_encoder(args)
    decoder = cls.build_decoder(args)

    # Pre-nets
    text_encoder_prenet = cls.build_text_encoder_prenet(text_encoder_embed_tokens, args)
    speech_encoder_prenet = cls.build_speech_encoder_prenet(args)
    text_decoder_prenet = cls.build_text_decoder_prenet(text_decoder_embed_tokens, args)
    # "decoder-las" pooling feeds the decoder through a speech *encoder* prenet.
    if getattr(args, "sid_pooling_layer", None) == "decoder-las":
        speech_decoder_prenet = cls.build_speech_encoder_prenet(args)
    else:
        speech_decoder_prenet = cls.build_speech_decoder_prenet(speech_odim, args)

    # Post-nets
    text_decoder_postnet = cls.build_text_decoder_postnet(
        text_decoder_embed_tokens, task.dicts['text'], args
    )
    speech_decoder_postnet = cls.build_speech_decoder_postnet(speech_odim, args)

    # The speaker post-net exists only for the s2c task without the T5 post-net.
    speaker_decoder_postnet = None
    if not getattr(args, "sid_t5_postnet", False) and task.t5_task == "s2c":
        speaker_decoder_postnet = cls.build_speaker_decoder_postnet(
            args.sid_embed_dim, len(task.dicts['text']), args
        )

    speech_encoder_postnet = (
        cls.build_speech_encoder_postnet(task.dicts['hubert'], args)
        if "hubert" in task.dicts
        else None
    )

    return cls(
        args,
        encoder, decoder,
        text_encoder_prenet, speech_encoder_prenet,
        text_decoder_prenet, speech_decoder_prenet,
        text_decoder_postnet, speech_decoder_postnet,
        speaker_decoder_postnet, speech_encoder_postnet,
    )
def get_normalized_probs(
    self,
    net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
    log_probs: bool,
    sample: Optional[Dict[str, Tensor]] = None,
):
    """Return (log-)probabilities for ``net_output``, tagged as batch-first.

    Delegates to ``get_normalized_probs_scriptable`` and sets the
    ``batch_first`` attribute of the result to True.
    """
    # net_output['encoder_out'] is a (B, T, D) tensor
    probs = self.get_normalized_probs_scriptable(net_output, log_probs, sample)
    probs.batch_first = True
    return probs
def get_normalized_probs_for_ctc(self, net_output, log_probs):
    """Get normalized probabilities (or log probs) from a net's output.

    Reads the first entry of ``net_output["encoder_out_for_ctc"]``, casts it
    to float and normalizes over the last dimension.
    """
    ctc_logits = net_output["encoder_out_for_ctc"][0]
    if not log_probs:
        return utils.softmax(ctc_logits.float(), dim=-1)
    return utils.log_softmax(ctc_logits.float(), dim=-1)
def get_logits(self, net_output, is_masked=True):
    """Return the masked (or unmasked) logits as float tensors.

    Picks ``net_output["logit_m_list"]`` when ``is_masked`` is true and
    ``net_output["logit_u_list"]`` otherwise; ``None`` entries are dropped.
    """
    key = "logit_m_list" if is_masked else "logit_u_list"
    return [logit.float() for logit in net_output[key] if logit is not None]
def get_targets(self, sample, net_output, is_masked=True):
    """Return the training targets for ``net_output``.

    When ``net_output`` carries ``"logit_m_list"``, one all-zero ``long``
    vector is produced per logits tensor (length = its first dimension);
    otherwise ``sample["target"]`` is returned unchanged.
    """
    if "logit_m_list" not in net_output:
        return sample["target"]
    return [
        logit.new_zeros(logit.size(0), dtype=torch.long)
        for logit in self.get_logits(net_output, is_masked)
    ]
def get_extra_losses(self, net_output):
    """Collect auxiliary losses present in ``net_output``.

    Returns:
        ``(extra_losses, names)``: two parallel lists. ``"features_pen"`` is
        passed through as is; ``"prob_perplexity"`` is turned into
        ``(num_vars - prob_perplexity) / num_vars``.
    """
    collected = []
    if "features_pen" in net_output:
        collected.append(("features_pen", net_output["features_pen"]))
    if "prob_perplexity" in net_output:
        diversity = (
            (net_output["num_vars"] - net_output["prob_perplexity"])
            / net_output["num_vars"]
        )
        collected.append(("prob_perplexity", diversity))

    extra_losses = [value for _, value in collected]
    names = [name for name, _ in collected]
    return extra_losses, names
def forward(self, source=None, src_tokens=None, src_lengths=None, prev_output_tokens=None, tgt_lengths=None, spkembs=None, target_list=None, task_name=None, padding_mask=None, only_hubert=False, only_ctc=False, feature_only=False, tgt_enc_layer=None, mask=True):
"""
Run the pre-net / encoder / decoder / post-net pipeline for one batch.

The forward method inherited from the base class has a **kwargs
argument in its input, which is not supported in torchscript. This
method overwrites the forward method definition without **kwargs.

Modalities are inferred from the inputs:
- input is 'text' when ``source`` and ``padding_mask`` are both None and
``feature_only`` is False, otherwise 'speech';
- output is 'text' when ``prev_output_tokens`` is a 2-D tensor, otherwise
'speech'.
``task_name`` ("s2c", "s2t", "s2s", "speech_pretrain") selects additional
task-specific branches. At least one of ``source`` / ``src_tokens`` must be
given.

The return value depends on the branch taken:
- "speech_pretrain" with ``feature_only``: the encoder output with
dimensions 0 and 1 swapped;
- "s2c": ``(speaker_decoder_postnet(...), None)``, or
``((text_decoder_postnet(...), None), encoder_output)``;
- ``only_hubert`` with targets: ``(hubert_results, None)``;
- ``only_ctc`` with "s2t": ``(None, encoder_output)``;
- "s2t" in eval mode without ``prev_output_tokens``: ``encoder_output``;
- "s2s" with ``se_predict`` set: ``(before_outs, after_outs, logits, attn)``;
- "s2t": ``((text_decoder_postnet(...), None), encoder_output)``;
- text output otherwise:
``((text_decoder_postnet(...), None), codebook_out, encoder_output)``;
- speech output otherwise: the speech post-net outputs plus the attention,
preceded by ``hubert_results`` when ``target_list`` is given.
"""
assert source is not None or src_tokens is not None
# padding_mask is not none only when input is waveform
if source is None and padding_mask is None and not feature_only:
input_type = 'text'
else:
input_type = 'speech'
# A 2-D prev_output_tokens selects text output; anything else (including
# None) is treated as speech output.
if prev_output_tokens is not None and len(prev_output_tokens.size()) == 2:
output_type = 'text'
# Filled with quantizer statistics below when use_codebook is set.
codebook_out = {}
else:
output_type = 'speech'
# s2c: a single-column target_list becomes a one-hot speaker target unless
# sid_t5_postnet is set; target_list is then cleared so the target-based
# branches below are skipped.
if task_name is not None and task_name == "s2c":
if target_list is not None and target_list.size(1) == 1 and not getattr(self.args, "sid_t5_postnet", False):
sid_target = F.one_hot(target_list.squeeze(1), num_classes=self.speaker_decoder_postnet.class_num)
else:
sid_target = None
target_list = None
# Encoder Prenet
if input_type == 'text':
encoder_input, encoder_padding_mask = self.text_encoder_prenet(src_tokens)
else:
if target_list is not None:
# With targets, the prenet's first output is a 4-tuple that also carries
# the feature penalty, the mask indices and target_list.
encoder_input, encoder_padding_mask = self.speech_encoder_prenet(source, require_feat_pen=True, target_list=target_list, padding_mask=padding_mask, mask=mask)
encoder_input, features_pen, mask_indices, target_list = encoder_input
else:
# NOTE(review): this call passes mask=self.training, so the `mask`
# argument of forward() is ignored when there are no targets.
encoder_input, encoder_padding_mask = self.speech_encoder_prenet(source, padding_mask=padding_mask, mask=self.training)
# shuffle a batch of inputs of encoder
# (a random permutation is applied along dimension 1 of input and mask)
if self.training and hasattr(self.args, "sid_shuffle_encoder_input") and getattr(self.args, "sid_shuffle_encoder_input", False):
shuffle_index = torch.randperm(encoder_padding_mask.size(1), device=encoder_padding_mask.device)
encoder_input = torch.index_select(encoder_input, 1, shuffle_index)
encoder_padding_mask = torch.index_select(encoder_padding_mask, 1, shuffle_index)
# With sid_encoder_cls == "encoder", prev_output_tokens is zeroed and
# combined with the encoder input by _integrate_with_speaker_cls.
if getattr(self.args, "sid_encoder_cls", None) == "encoder":
prev_output_tokens = torch.zeros_like(prev_output_tokens)
encoder_input, encoder_padding_mask = self._integrate_with_speaker_cls(prev_output_tokens, encoder_input, encoder_padding_mask)
# Encoder: T x B x C
encoder_output = self.encoder(encoder_input, encoder_padding_mask, tgt_layer=tgt_enc_layer)
# Feature extraction for speech_pretrain: return the encoder output with
# dimensions 0 and 1 swapped and stop here.
if task_name is not None and task_name == 'speech_pretrain' and feature_only:
return encoder_output["encoder_out"][0].transpose(0, 1)
# s2c with encoder-side pooling: classify from the encoder output (mean
# over dimension 1, index 0 of dimension 1, or the full sequence) and return
# without running the decoder.
if task_name is not None and task_name == 's2c':
if self.args.sid_pooling_layer == "encoder":
return self.speaker_decoder_postnet(encoder_output["encoder_out"][0].transpose(0, 1).mean(1), sid_target), None
elif self.args.sid_pooling_layer == "encoder-cls":
return self.speaker_decoder_postnet(encoder_output["encoder_out"][0].transpose(0, 1)[:,0], sid_target), None
elif self.args.sid_pooling_layer == "encoder-speaker" or getattr(self.args, "sid_decoder_speaker", False):
return self.speaker_decoder_postnet(encoder_output["encoder_out"][0].transpose(0, 1), sid_target), None
# With targets, run the speech encoder post-net (stored as hubert_layer) and
# attach the feature penalty returned by the prenet.
if target_list is not None:
hubert_results = self.hubert_layer(
encoder_output["encoder_out"][0].transpose(0, 1),
encoder_padding_mask,
mask_indices,
target_list
)
hubert_results['features_pen'] = features_pen
if "decoder_input" in encoder_output and encoder_output["decoder_input"][0] is not None:
# Change the encoder output to decoder input once set unb-enc-layer
encoder_output["encoder_out"] = encoder_output["decoder_input"]
# Codebook mixing: a random subset of positions along dimension 1 (a share
# of codebook_prob) takes the quantized vector, the rest keep the encoder
# output.
# NOTE(review): hubert_results is only assigned above when target_list is
# not None, but it is written to below whenever output_type == 'speech'.
# With use_codebook, speech output and no targets this looks like it would
# raise an unbound-variable error — confirm whether that combination occurs.
if self.use_codebook:
q = self.quantizer(encoder_output["encoder_out"][0].transpose(0, 1))
# q["x"]: B x T x C
# Sample indexs according to the codebook prob
random_idx = torch.randperm(q["x"].size(1))[:int(q["x"].size(1) * self.codebook_prob)]
# Make weight for q
q_w = q["x"].new_zeros(q["x"].size(1))
q_w[random_idx] = 1.0
# Combine quantized codes and encoder output
encoder_output["encoder_out"][0] = (
q_w.view(-1, 1) * q["x"] + (- q_w + 1).view(-1, 1) * encoder_output["encoder_out"][0].transpose(0, 1)
).transpose(0, 1)
# encoder_output["encoder_out"][0] = q["x"].transpose(0, 1)
if output_type == 'speech':
hubert_results["prob_perplexity"] = q["prob_perplexity"]
hubert_results["code_perplexity"] = q["code_perplexity"]
hubert_results["num_vars"] = q["num_vars"]
hubert_results["temp"] = q["temp"]
elif output_type == 'text':
codebook_out["prob_perplexity"] = q["prob_perplexity"]
codebook_out["code_perplexity"] = q["code_perplexity"]
codebook_out["num_vars"] = q["num_vars"]
codebook_out["temp"] = q["temp"]
# Early exits that skip the decoder.
if only_hubert and target_list is not None:
return hubert_results, None
if only_ctc and task_name is not None and task_name == "s2t":
return None, encoder_output
# (the trailing `task_name is not None` test is redundant after the
# equality check)
elif not self.training and prev_output_tokens is None and task_name == "s2t" and task_name is not None:
return encoder_output
# Decoder Prenet
if output_type == 'text':
# _ is the incremental state
prev_output_tokens, tgt_mask, _ = self.text_decoder_prenet(prev_output_tokens)
# For s2c the prenet output is replaced by zeros of the same shape.
if task_name is not None and task_name == 's2c':
prev_output_tokens = torch.zeros_like(prev_output_tokens)
else:
# integrate speaker embedding
if self.spk_embed_integration_type == "pre" and self.spk_embed_dim is not None:
# Decoder Prenet
prev_output_tokens, tgt_mask = self.speech_decoder_prenet(prev_output_tokens, tgt_lengths, spkembs)
else:
# Otherwise the speaker embedding (if any) is merged into the encoder
# output before the speech decoder prenet runs.
if self.spk_embed_dim is not None:
encoder_output["encoder_out"] = [self._integrate_with_spk_embed(
encoder_output["encoder_out"][0].transpose(0, 1), spkembs
).transpose(0, 1)]
prev_output_tokens, tgt_mask = self.speech_decoder_prenet(prev_output_tokens, tgt_lengths)
# BART Sequence Classification: cat <pad> + feature before decoder
if task_name is not None and task_name == 's2c' and self.args.sid_pooling_layer == "decoder-las":
decoder_feat_input, decoder_feat_mask = self.speech_decoder_prenet(src_tokens, src_lengths)
prev_output_tokens, tgt_mask = self._integrate_with_speaker_cls((prev_output_tokens, tgt_mask), decoder_feat_input, decoder_feat_mask, cls_first=False)
# SE predict masking to corresponding inputs and source speech replaces the prev_output_tokens as the input of decoder
if task_name is not None and task_name == "s2s" and getattr(self.args, "se_decoder_input", "previous_target") == "source":
prev_output_tokens, tgt_mask = self.speech_decoder_prenet(src_tokens, src_lengths)
# Decoder
# alignment_layer is -1 only for speech output without targets, else None.
decoder_output, extra = self.decoder(prev_output_tokens, tgt_mask, encoder_output,
full_context_alignment=getattr(self.args, "decoder_full_context_alignment", False),
alignment_layer=(-1 if target_list is None and output_type == 'speech' else None))
# Decoder Postnet
# s2c with decoder-side pooling: classify from the mean over dimension 1
# ("decoder") or from the last position not flagged in tgt_mask
# ("decoder-las"); the text post-net is used in the final else branch.
# NOTE(review): for any other sid_pooling_layer value no return appears to
# be taken here and execution continues below — confirm against the
# original indentation.
if task_name is not None and task_name == 's2c':
if not getattr(self.args, "sid_t5_postnet", False):
if self.args.sid_pooling_layer == "decoder":
return self.speaker_decoder_postnet(decoder_output.mean(1), sid_target), None
elif self.args.sid_pooling_layer == "decoder-las":
indices = (tgt_mask.eq(False).float().sum(1) - 1.0).type(torch.int64)
indices = indices.unsqueeze(1).unsqueeze(2).expand(-1, -1, decoder_output.size(2))
return self.speaker_decoder_postnet(decoder_output.gather(1, indices), sid_target), None
else:
return (self.text_decoder_postnet(decoder_output), None), encoder_output
# SE predict: masking, target, delta. Ensure reduction factor 1
if task_name is not None and task_name == 's2s' and getattr(self.args, "se_predict", None) is not None:
assert self.reduction_factor == 1, f"{self.reduction_factor} != 1"
before_outs, after_outs, logits = self.speech_decoder_postnet(decoder_output)
se_predict = getattr(self.args, "se_predict")
if se_predict == "masking":
# Outputs act as a sigmoid mask multiplied onto src_tokens.
before_outs = torch.sigmoid(before_outs) * src_tokens
after_outs = torch.sigmoid(after_outs) * src_tokens
return before_outs, after_outs, logits, extra['attn'][0]
elif se_predict == "target":
return before_outs, after_outs, logits, extra['attn'][0]
elif se_predict == "delta":
# src_tokens is subtracted from the post-net outputs.
before_outs = before_outs - src_tokens
after_outs = after_outs - src_tokens
return before_outs, after_outs, logits, extra['attn'][0]
else:
raise ValueError(f"{se_predict} not in [masking, target, delta]")
if task_name is not None and task_name == 's2t':
#return self.text_decoder_postnet(decoder_output), None
return (self.text_decoder_postnet(decoder_output), None), encoder_output
# Generic outputs: text also returns the codebook statistics; speech
# returns the post-net outputs with the attention appended.
if output_type == 'text':
return (self.text_decoder_postnet(decoder_output), None), codebook_out, encoder_output
else:
if target_list is not None:
return hubert_results, (self.speech_decoder_postnet(decoder_output) + (extra['attn'][0],))
else:
return self.speech_decoder_postnet(decoder_output) + (extra['attn'][0],)
def _integrate_with_speaker_cls(self, pad_input, encoder_input, encoder_padding_mask=None, cls_first=True):
"""
Combine a CLS/pad vector with a feature sequence and build a padding mask.

pad_input: either a ``(vector, mask)`` tuple used as is, or an input that
is passed through ``self.text_decoder_prenet`` to obtain the vector and mask
encoder_input: [B, T, C]
encoder_padding_mask: [B, T]
cls_first: if True the CLS vector is concatenated in front of
``encoder_input`` along dimension 1; otherwise it is written at the first
position after the unmasked frames of each row

Returns ``(ret_encoder_input, ret_encoder_padding_mask)``.
"""
if hasattr(self, "text_decoder_prenet"):
if isinstance(pad_input, tuple):
repeat_cls_vector, repeat_cls_mask = pad_input
else:
repeat_cls_vector, repeat_cls_mask, _ = self.text_decoder_prenet(pad_input)
# NOTE(review): when a mask IS supplied it is replaced by an all-False mask
# of shape (bsz, tsz); when it is None it stays None and the `.size(0)`
# calls below would fail. The condition may have been meant as `is None` —
# confirm before changing, since existing checkpoints ran with this
# behaviour.
if encoder_padding_mask is not None:
bsz = encoder_input.size(0)
tsz = encoder_input.size(1)
encoder_padding_mask = encoder_input.new_zeros((bsz, tsz)) == 1.0
# No CLS mask available: use a single all-False column per row.
if repeat_cls_mask is None:
mask_size = (encoder_padding_mask.size(0), 1)
# mask_type is assigned but never used.
mask_type = encoder_padding_mask.dtype
repeat_cls_mask = encoder_padding_mask.new_zeros(mask_size) == 1.0
# The returned mask is always [cls mask, encoder mask] along dimension 1,
# for both values of cls_first.
ret_encoder_padding_mask = torch.cat([repeat_cls_mask, encoder_padding_mask], dim=1)
if cls_first:
ret_encoder_input = torch.cat([repeat_cls_vector, encoder_input], dim=1)
else:
# Extend the sequence by one step (a copy of the last frame) to make room
# for the CLS vector.
ret_encoder_input = torch.cat([encoder_input, encoder_input[:,-1:,:]], dim=1)
mask_size = (encoder_padding_mask.size(0), 1)
# mask_type is assigned but never used.
mask_type = encoder_padding_mask.dtype
repeat_cls_mask_ = encoder_padding_mask.new_ones(mask_size) == 1.0
# Original mask plus an all-True column for the appended step.
encoder_padding_mask_ = torch.cat([encoder_padding_mask, repeat_cls_mask_], dim=1)
# Per row: number of unmasked frames = index where the CLS vector goes.
indices = encoder_padding_mask.eq(False).float().sum(1).type(torch.int64).unsqueeze(1)
indices_mask = torch.zeros_like(ret_encoder_padding_mask).scatter(1, indices, 1.0)
# Zero every position flagged in encoder_padding_mask_, then add the CLS
# vector at the selected index.
ret_encoder_input = ret_encoder_input * (1.0 - encoder_padding_mask_.type(ret_encoder_input.dtype).unsqueeze(2)) \
+ repeat_cls_vector * indices_mask.type(repeat_cls_vector.dtype).unsqueeze(2)
return ret_encoder_input, ret_encoder_padding_mask
def _integrate_with_spk_embed(self, hs, spembs):
"""Integrate speaker embedding with hidden states.
Args:
hs (Tensor): Batch of hidden state sequences (B, Tmax, adim).
spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim).
Returns:
Tensor: Batch of integrated hidden state sequences (B, Tmax, adim)
"""
if self.spk_embed_integration_type == "add":
# apply projection and then add to hidden states
spembs = self.projection(F.normalize(spembs))
hs = hs + spembs.unsqueeze(1)
elif self.spk_embed_integration_type == "concat":
# concat hidden states with spk embeds and then apply projection
spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
hs = self.projection(torch.cat([hs, spembs], dim=-1))
else:
raise NotImplementedError("support only add or concat.")
return hs
def load_state_dict(
    self,
    state_dict,
    strict=True,
    model_cfg=None,
    args=None,
):
    """Load ``state_dict`` non-strictly, one direct sub-module at a time.

    Overrides the method in :class:`nn.Module`. Every direct child found in
    ``self._modules`` receives the checkpoint entries whose key starts with
    ``"<child name>."`` and loads them with ``strict=False``.

    If the checkpoint's ``text_decoder_postnet.output_projection.weight``
    has a different first dimension than this model's
    ``text_decoder_postnet.output_projection.out_features``, every entry
    under ``encoder.proj``, ``text_encoder_prenet``, ``text_decoder_prenet``
    and ``text_decoder_postnet`` is removed from ``state_dict`` so those
    modules keep their current parameters.

    Args:
        state_dict (dict): Checkpoint parameters keyed by dotted names.
            Modified in place when the dictionary sizes differ.
        strict (bool): Accepted for interface compatibility only; loading
            is always non-strict.
        model_cfg: Unused.
        args: Unused.

    Returns:
        This model (``self``).
    """
    vocab_key = "text_decoder_postnet.output_projection.weight"
    postnet = getattr(self, "text_decoder_postnet", None)
    # The size comparison needs both sides: skip it when the post-net has
    # been pruned from the model or the checkpoint carries no such weight.
    if postnet is not None and vocab_key in state_dict:
        model_dict_size = postnet.output_projection.out_features
        ckpt_dict_size = state_dict[vocab_key].size(0)
        if model_dict_size != ckpt_dict_size:
            # reset dictionary-related modules, such as embedding table and encoder ctc embed
            logger.warning(f"not equal dictionary between model and checkpoint: {model_dict_size} vs {ckpt_dict_size}")
            logger.info(f"reset model dictionary with size of {model_dict_size}")
            vocab_prefixes = (
                "encoder.proj", "text_encoder_prenet", "text_decoder_prenet", "text_decoder_postnet"
            )
            # collect first, then pop: the dict must not change while iterated
            removed_keys = [key for key in state_dict if key.startswith(vocab_prefixes)]
            for key in removed_keys:
                state_dict.pop(key, None)
                logger.info(f"removed loaded checkpoint: {key}")

    for name, module in self._modules.items():
        if module is None:
            # a child registered as None has nothing to load into
            continue
        prefix = f"{name}."
        # Strip only the leading "<name>."; str.replace would also remove
        # later occurrences (e.g. "encoder.encoder.w" -> "w") and the entry
        # would then be silently ignored by the non-strict load.
        m_state_dict = {
            key[len(prefix):]: value for key, value in state_dict.items() if key.startswith(prefix)
        }
        module.load_state_dict(m_state_dict, False)
    return self
def prune_modules(self, modules_filter=None):
    """Remove the sub-modules that the selected task does not use.

    Args:
        modules_filter (str, optional): Task tag, one of ``"s2c"``,
            ``"s2s"``, ``"t2s"`` or ``"s3prl"``. ``None`` or any other
            value leaves the model untouched.
    """
    if modules_filter not in ("s2c", "s2s", "t2s", "s3prl"):
        return

    def drop(owner, *names):
        # delete each attribute that is present; absent ones are ignored
        for name in names:
            if hasattr(owner, name):
                delattr(owner, name)

    decoder_stack = ("dropout_module", "layers", "layer_norm")

    if modules_filter == "s2c":
        drop(self, "text_encoder_prenet")
        # the speech decoder pre-net survives only for "decoder-las" pooling
        if hasattr(self, "speech_decoder_prenet") and getattr(self.args, "sid_pooling_layer", None) != "decoder-las":
            del self.speech_decoder_prenet
        drop(self, "speech_decoder_postnet")
        pooling = getattr(self.args, "sid_pooling_layer", "decoder")
        if pooling.startswith("encoder") or getattr(self.args, "sid_decoder_speaker", False):
            drop(self.decoder, *decoder_stack)
    elif modules_filter == "s2s":
        drop(self, "speaker_decoder_postnet", "text_encoder_prenet")
    elif modules_filter == "t2s":
        drop(self, "speaker_decoder_postnet", "speech_encoder_prenet")
    else:
        # "s3prl": remain the encoder and the pre/post net
        drop(self.decoder, *decoder_stack)
        drop(self, "speaker_decoder_postnet", "speech_decoder_prenet", "speech_decoder_postnet")

    # removals shared by all four filters
    drop(
        self,
        "text_decoder_prenet",
        "text_decoder_postnet",
        "speech_encoder_postnet",
        "projection",
        "quantizer",
    )
    if hasattr(self.encoder, "proj"):
        self.encoder.proj = None
def forward_encoder_torchscript(self, net_input: Dict[str, Tensor]):
    """Run the encoder through an entry point that TorchScript can compile.

    While scripting, ``forward_encoder`` is called with the ``"source"``
    and ``"padding_mask"`` entries of ``net_input``; in eager mode the call
    is delegated to ``forward_encoder_non_torchscript``.

    Encoders which use additional arguments may want to override
    this method for TorchScript compatibility.
    """
    if not torch.jit.is_scripting():
        return self.forward_encoder_non_torchscript(net_input)

    return self.forward_encoder(
        source=net_input["source"],
        padding_mask=net_input["padding_mask"],
    )
@torch.jit.unused
def forward_encoder_non_torchscript(self, net_input: Dict[str, Tensor]):
    """Eager-mode encoder call: forward ``net_input`` as keyword arguments.

    The ``"prev_output_tokens"`` and ``"task_name"`` entries are left out;
    every other entry is passed to ``forward_encoder`` under its own key.
    ``net_input`` itself is not modified.
    """
    skipped = ("prev_output_tokens", "task_name")
    encoder_kwargs = {}
    for name, value in net_input.items():
        if name not in skipped:
            encoder_kwargs[name] = value
    return self.forward_encoder(**encoder_kwargs)
def forward_encoder(self, source, padding_mask=None):
    """Encode speech input: speech encoder pre-net followed by the encoder.

    Args:
        source: Input handed to ``self.speech_encoder_prenet``.
        padding_mask: Optional padding mask forwarded to the pre-net.

    Returns:
        Whatever ``self.encoder`` returns for the pre-net features and the
        padding mask produced by the pre-net.
    """
    # masking is switched off explicitly for this path (mask=False)
    prenet_out = self.speech_encoder_prenet(source, padding_mask=padding_mask, mask=False)
    features, features_padding_mask = prenet_out
    return self.encoder(features, features_padding_mask)
def forward_text_encoder(self, src_tokens):
    """Encode text input: text encoder pre-net followed by the encoder.

    Args:
        src_tokens: Input handed to ``self.text_encoder_prenet``.

    Returns:
        Whatever ``self.encoder`` returns for the pre-net output and the
        padding mask produced by the pre-net.
    """
    prenet_out = self.text_encoder_prenet(src_tokens)
    embedded, embedded_padding_mask = prenet_out
    return self.encoder(embedded, embedded_padding_mask)
def forward_decoder(self, tokens, encoder_out, incremental_state):
    """Run one text-decoding pass: pre-net, decoder, then post-net.

    Args:
        tokens: Previous output tokens handed to ``self.text_decoder_prenet``.
        encoder_out: Encoder output passed to the decoder.
        incremental_state: Decoding state handed to the pre-net; the state
            returned by the pre-net is the one passed on to the decoder.

    Returns:
        tuple: ``(postnet_output, extra)`` where ``extra`` is the second
        value returned by ``self.decoder``.
    """
    prenet_out = self.text_decoder_prenet(tokens, incremental_state)
    decoder_input, tgt_mask, incremental_state = prenet_out

    hidden, extra = self.decoder(
        decoder_input,
        tgt_mask,
        encoder_out=encoder_out,
        incremental_state=incremental_state,
    )

    postnet_output = self.text_decoder_postnet(hidden)
    return postnet_output, extra
def set_num_updates(self, num_updates):
    """Set the number of parameter updates.

    The value is first forwarded to the parent class implementation and
    then stored on this instance as ``self.num_updates``.
    """
    super().set_num_updates(num_updates)
    # mirror the counter on the instance
    self.num_updates = num_updates
def generate_class(self, source, prev_output_tokens, **kwargs):
    """Predict a class index for each input utterance.

    The speech input is encoded with ``forward_encoder``; the decoder is
    then fed an all-zero tensor shaped like the text decoder pre-net output,
    its output is averaged over dim 1 and passed to
    ``self.speaker_decoder_postnet``.

    Args:
        source: Speech input for ``forward_encoder``.
        prev_output_tokens: Tokens handed to ``self.text_decoder_prenet``;
            only the shape of the pre-net output and the returned target
            mask are used.
        **kwargs: Must contain ``"padding_mask"`` for the encoder.

    Returns:
        Tensor: ``argmax`` over dim 1 of the post-net's first output.
    """
    encoder_out = self.forward_encoder(source, padding_mask=kwargs["padding_mask"])

    prenet_out, tgt_mask, _ = self.text_decoder_prenet(prev_output_tokens, {})
    # s2c use zero vector as [CLS]
    cls_input = torch.zeros_like(prenet_out)

    decoder_output, _ = self.decoder(
        cls_input,
        tgt_mask,
        encoder_out=encoder_out,
    )

    pooled = decoder_output.mean(1)
    class_scores, _ = self.speaker_decoder_postnet(pooled)
    return class_scores.argmax(1)
def generate_speech(self, source=None, src_tokens=None, spkembs=None, **kwargs):
    """Autoregressively generate speech features for a single utterance.

    The input is encoded with ``forward_encoder`` when ``source`` is given,
    otherwise with ``forward_text_encoder`` on ``src_tokens``. The decoder
    is then run step by step until a stop probability reaches ``threshold``
    or the maximum number of steps is hit.

    Args:
        source (Tensor, optional): Speech input; its first dimension must
            be 1.
        src_tokens (Tensor, optional): Text input, used only when
            ``source`` is None; its first dimension must be 1.
        spkembs (Tensor, optional): Speaker embedding. Unless the
            integration type is "pre" it is merged into the encoder output
            and not passed to the decoder pre-net.
        **kwargs: Optional generation settings:
            ``threshold`` (default 0.5) stop-probability threshold;
            ``minlenratio`` (default 0.0) and ``maxlenratio`` (default 20.0
            for text input, 10.0 for speech input) bound the number of
            decoder steps relative to the encoder output length;
            ``padding_mask`` is required when ``source`` is given.

    Returns:
        tuple: ``(outs, probs, attn)`` where ``outs`` is the generated
        feature sequence (L, odim), ``probs`` the concatenated stop
        probabilities and ``attn`` the per-step attention tensors
        concatenated along dim 2.
    """
    assert source is not None or src_tokens is not None

    threshold = kwargs.get("threshold", 0.5)
    # Each length ratio is read from its own key. They used to be looked up
    # under "threshold", so passing a threshold silently replaced both the
    # minimum and the maximum length ratio with that value.
    minlenratio = kwargs.get("minlenratio", 0.0)

    if source is None:
        assert src_tokens.size(0) == 1
        encoder_out = self.forward_text_encoder(src_tokens)
        maxlenratio = kwargs.get("maxlenratio", 20.0)
    else:
        assert source.size(0) == 1
        encoder_out = self.forward_encoder(source, padding_mask=kwargs["padding_mask"])
        maxlenratio = kwargs.get("maxlenratio", 10.0)

    if spkembs is not None and self.spk_embed_integration_type != "pre":
        # the helper works on the transposed layout, hence the round trip
        encoder_out["encoder_out"] = [self._integrate_with_spk_embed(
            encoder_out["encoder_out"][0].transpose(0, 1), spkembs
        ).transpose(0, 1)]
        spkembs = None

    enc_len = encoder_out["encoder_out"][0].size(0)
    maxlen = int(enc_len * maxlenratio / self.reduction_factor)
    minlen = int(enc_len * minlenratio / self.reduction_factor)

    postnet = self.speech_decoder_postnet
    odim = postnet.odim

    idx = 0
    ys = encoder_out["encoder_out"][0].new_zeros(1, 1, odim)
    outs, probs, attns = [], [], []

    # forward decoder step-by-step
    if isinstance(self.decoder, FairseqIncrementalDecoder):
        incremental_states = {}
    else:
        incremental_states = None

    while True:
        # update index
        idx += 1

        # calculate output and stop prob at idx-th step
        decoder_in, _ = self.speech_decoder_prenet(ys, spkembs=spkembs)
        z, extra = self.decoder(decoder_in[:, -1:], None, encoder_out, incremental_states, alignment_layer=-1)
        outs += [postnet.feat_out(z[0, -1]).view(self.reduction_factor, odim)]  # [(r, odim), ...]
        probs += [torch.sigmoid(postnet.prob_out(z[0, -1]))]  # [(r), ...]

        # update next inputs: the last generated frame is fed back
        ys = torch.cat((ys, outs[-1][-1].view(1, 1, odim)), dim=1)  # (1, idx + 1, odim)
        attns.append(torch.stack([att_l[0] for att_l in extra['attn'][0]], dim=0))

        # check whether to finish generation
        if bool((probs[-1] >= threshold).any()) or idx >= maxlen:
            # check minimum length
            if idx < minlen:
                continue
            outs = torch.cat(outs, dim=0).unsqueeze(0).transpose(1, 2)  # (L, odim) -> (1, L, odim) -> (1, odim, L)
            if postnet.postnet is not None:
                outs = outs + postnet.postnet(outs)  # (1, odim, L)
            outs = outs.transpose(2, 1).squeeze(0)  # (L, odim)
            probs = torch.cat(probs, dim=0)
            attn = torch.cat(attns, dim=2)
            break

    # ``outs`` holds idx * reduction_factor frames, so the step counter (not
    # the frame count) is what has to be compared with the step limit.
    if idx >= maxlen:
        logging.warning("output length reaches maximum length")
    return outs, probs, attn
@register_model_architecture(model_name="artst_transformer", arch_name="artst_transformer")
def base_architecture(args):
    """Fill in every hyper-parameter that ``args`` does not define yet.

    Attributes already present on ``args`` keep their value; missing ones
    receive the defaults listed below. A few defaults are derived from
    attributes resolved earlier (decoder sizes from the encoder sizes,
    attention/activation dropout from ``dropout``), so the order of the
    groups matters.
    """

    def fill(**defaults):
        # keep a value that is already set, otherwise install the default
        for name, value in defaults.items():
            setattr(args, name, getattr(args, name, value))

    # Transformer
    fill(
        bert_init=False,
        encoder_embed_dim=768,
        encoder_ffn_embed_dim=768 * 4,
        encoder_layers=12,
        encoder_attention_heads=12,
        encoder_normalize_before=False,
    )
    # decoder sizes fall back to the (now resolved) encoder sizes
    fill(decoder_embed_dim=args.encoder_embed_dim)
    fill(decoder_ffn_embed_dim=args.encoder_ffn_embed_dim)
    fill(
        decoder_layers=6,
        decoder_attention_heads=12,
        decoder_normalize_before=False,
        dropout=0.1,
    )
    # the specialised dropouts fall back to the general dropout
    fill(attention_dropout=args.dropout)
    fill(activation_dropout=args.dropout)
    fill(activation_fn="gelu", decoder_layerdrop=0.0)
    fill(decoder_output_dim=args.decoder_embed_dim)
    fill(decoder_input_dim=args.decoder_embed_dim)
    fill(
        encoder_layerdrop=0,
        max_text_positions=DEFAULT_MAX_TEXT_POSITIONS,
        max_speech_positions=DEFAULT_MAX_SPEECH_POSITIONS,
    )

    # Espnet related, including prenet, postnet
    fill(
        eprenet_conv_layers=0,
        eprenet_conv_filts=0,
        eprenet_conv_chans=0,
        use_batch_norm=True,
        eprenet_dropout_rate=0.0,
        enc_use_scaled_pos_enc=True,
        dec_use_scaled_pos_enc=True,
        postnet_layers=5,
        postnet_chans=256,
        postnet_filts=5,
        postnet_dropout_rate=0.5,
        dprenet_dropout_rate=0.5,
        dprenet_layers=2,
        dprenet_units=256,
        initial_encoder_alpha=1.0,
        initial_decoder_alpha=1.0,
        spk_embed_integration_type="pre",
        spk_embed_dim=512,
        encoder_reduction_factor=1,
        reduction_factor=2,
        transformer_enc_positional_dropout_rate=0.1,
        transformer_dec_positional_dropout_rate=0.1,
        layer_norm_eps=1e-5,
        no_scale_embedding=True,
    )

    # Convolutional subsampler
    fill(
        encoder_speech_prenet="conv",
        conv_kernel_sizes="5,5",
        conv_channels=1024,
        quant_noise_pq=0,
        adaptive_softmax_cutoff=None,
        adaptive_softmax_dropout=0,
        no_token_positional_embeddings=False,
        adaptive_input=False,
        decoder_learned_pos=False,
        share_input_output_embed=False,
        share_ctc_embed=False,
        freeze_encoder_updates=0,
        freeze_decoder_updates=0,
        no_freeze_encoder_layer=None,
    )

    ## sid
    fill(
        sid_embed_dim=128,
        sid_pooling_layer="decoder",
        softmax_scale=1,
        softmax_margin=0,
        softmax_easy_margin=False,
        modules_filter=None,
    )

    ## Hubert
    fill(
        conv_pos=128,
        conv_pos_groups=16,
        target_glu=False,
        logit_temp=0.1,
        final_dim=256,
        untie_final_proj=True,
        feature_grad_mult=0.1,
        use_sent_enc_layer=True,
    )

    # hubert feature extractor
    fill(
        extractor_mode="default",
        conv_feature_layers="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
        conv_bias=False,
    )

    # mask
    fill(
        hubert_mask_length=10,
        mask_prob=0.0,
        mask_selection="static",
        mask_other=0,
        no_mask_overlap=False,
        mask_min_space=1,
    )

    # channel mask
    fill(
        mask_channel_length=10,
        mask_channel_prob=0.0,
        mask_channel_selection="static",
        mask_channel_other=0,
        no_mask_channel_overlap=False,
        mask_channel_min_space=1,
    )

    # loss computation
    fill(skip_masked=False, skip_nomask=False)

    # conv Pos
    fill(use_conv_pos=False, use_sinc_pos=False)

    # codebook
    fill(
        use_codebook=False,
        latent_vars=100,
        latent_groups=2,
        latent_dim=0,
        latent_temp=(2, 0.5, 0.999995),
        quantizer_depth=1,
        quantizer_factor=3,
        codebook_prob=0.5,
    )

    # Relative pos embed
    fill(
        relative_position_embedding=False,
        num_buckets=320,
        max_distance=1280,
        encoder_max_relative_position=160,
        decoder_max_relative_position=160,
    )
@register_model_architecture("artst_transformer", "artst_transformer_base")
def artst_transformer_base(args):
    """Apply the "base" preset, then the defaults of ``base_architecture``.

    Values already present on ``args`` are never overwritten.
    """
    preset = {
        "use_conv_pos": True,
        "use_sinc_pos": True,
        "layernorm_embedding": False,
        "encoder_normalize_before": False,
        "decoder_normalize_before": False,
        "layer_norm_first": False,
        "relative_position_embedding": True,
        "dropout": 0.1,
        "activation_dropout": 0.0,
        "attention_dropout": 0.1,
        "encoder_layerdrop": 0.05,
        "decoder_layerdrop": 0.05,
        "mask_prob": 0.80,
    }
    for name, default in preset.items():
        setattr(args, name, getattr(args, name, default))
    # everything the preset does not cover comes from the shared defaults
    base_architecture(args)
@register_model_architecture("artst_transformer", "artst_transformer_large")
def artst_transformer_large(args):
    """Apply the "large" preset, then the defaults of ``base_architecture``.

    Values already present on ``args`` are never overwritten.
    """
    preset = {
        "use_conv_pos": True,
        "use_sinc_pos": True,
        "encoder_normalize_before": False,
        "decoder_normalize_before": True,
        "layer_norm_first": True,
        "relative_position_embedding": True,
        "dropout": 0.0,
        "activation_dropout": 0.0,
        "attention_dropout": 0.0,
        "encoder_layerdrop": 0.0,
        "decoder_layerdrop": 0.0,
        "encoder_embed_dim": 1024,
        "encoder_layers": 24,
        "decoder_layers": 6,
        "encoder_ffn_embed_dim": 4096,
        "encoder_attention_heads": 16,
        "decoder_attention_heads": 16,
        "feature_grad_mult": 1.0,
        "extractor_mode": "layer_norm",
        "final_dim": 768,
        "mask_prob": 0.80,
    }
    for name, default in preset.items():
        setattr(args, name, getattr(args, name, default))
    # everything the preset does not cover comes from the shared defaults
    base_architecture(args)
@register_model_architecture("artst_transformer", "artst_transformer_base_asr")
def artst_transformer_base_asr(args):
    """Apply the "base ASR" preset, then the defaults of ``base_architecture``.

    Values already present on ``args`` are never overwritten.
    """
    preset = {
        "use_conv_pos": True,
        "use_sinc_pos": True,
        "encoder_normalize_before": False,
        "decoder_normalize_before": False,
        "layer_norm_first": False,
        "relative_position_embedding": True,
        "dropout": 0.1,
        "activation_dropout": 0.1,
        "attention_dropout": 0.1,
        "feature_grad_mult": 0.0,
        "encoder_layerdrop": 0.1,
        "decoder_layerdrop": 0.1,
        "mask_prob": 0.75,
        "mask_selection": "static",
        "mask_channel_length": 64,
        "mask_channel_prob": 0.5,
        "mask_channel_selection": "static",
        "max_text_positions": 600,
    }
    for name, default in preset.items():
        setattr(args, name, getattr(args, name, default))
    # everything the preset does not cover comes from the shared defaults
    base_architecture(args)