# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------
import logging
from ast import literal_eval
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor

from fairseq import utils
from fairseq.models import (
    FairseqEncoderDecoderModel,
    FairseqIncrementalDecoder,
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import Embedding
from fairseq.modules import GumbelVectorQuantizer
from fairseq.modules.transformer_sentence_encoder import init_bert_params

from .modules.text_encoder_prenet import TextEncoderPrenet
from .modules.text_decoder_prenet import TextDecoderPrenet
from .modules.text_decoder_postnet import TextDecoderPostnet
from .modules.speech_encoder_prenet import SpeechEncoderPrenet
from .modules.speech_encoder_postnet import SpeechEncoderPostnet
from .modules.speech_decoder_prenet import SpeechDecoderPrenet
from .modules.speech_decoder_postnet import SpeechDecoderPostnet
from .modules.speaker_decoder_postnet import SpeakerDecoderPostnet
from .modules.encoder import TransformerEncoder
from .modules.decoder import TransformerDecoder

logger = logging.getLogger(__name__)

DEFAULT_MAX_TEXT_POSITIONS = 450
DEFAULT_MAX_SPEECH_POSITIONS = 4000


@register_model("artst_transformer")
class ArTSTTransformerModel(FairseqEncoderDecoderModel):
    """Adapted Transformer model (https://arxiv.org/abs/1706.03762) for
    speech-to-text tasks. The Transformer encoder/decoder remains the same.
    A trainable input subsampler is prepended to the Transformer encoder to
    project inputs into the encoder dimension as well as downsample input
    sequences for computational efficiency."""

    def __init__(
        self,
        args,
        encoder, decoder,
        text_encoder_prenet, speech_encoder_prenet,
        text_decoder_prenet, speech_decoder_prenet,
        text_decoder_postnet, speech_decoder_postnet,
        speaker_decoder_postnet, speech_encoder_postnet,
    ):
        super().__init__(encoder, decoder)

        self.encoder = encoder
        self.decoder = decoder

        self.text_encoder_prenet = text_encoder_prenet
        self.speech_encoder_prenet = speech_encoder_prenet

        self.text_decoder_prenet = text_decoder_prenet
        self.speech_decoder_prenet = speech_decoder_prenet

        self.text_decoder_postnet = text_decoder_postnet
        self.speech_decoder_postnet = speech_decoder_postnet
        self.speaker_decoder_postnet = speaker_decoder_postnet
        self.hubert_layer = speech_encoder_postnet

        self.reduction_factor = args.reduction_factor
        self.spk_embed_dim = args.spk_embed_dim

        # define projection layer
        self.spk_embed_integration_type = args.spk_embed_integration_type
        if self.spk_embed_dim is not None and self.spk_embed_integration_type != 'pre':
            if self.spk_embed_integration_type == "add":
                self.projection = torch.nn.Linear(self.spk_embed_dim, args.decoder_embed_dim)
            else:
                self.projection = torch.nn.Linear(
                    args.decoder_embed_dim + self.spk_embed_dim, args.decoder_embed_dim
                )

        # Hawau: here we can add language embedding integration

        self.use_codebook = args.use_codebook
        self.codebook_prob = getattr(args, "codebook_prob", 0.5)
        if self.use_codebook:
            vq_dim = args.latent_dim if args.latent_dim > 0 else args.encoder_embed_dim
            self.quantizer = GumbelVectorQuantizer(
                dim=args.encoder_embed_dim,
                num_vars=args.latent_vars,
                temp=args.latent_temp,
                groups=args.latent_groups,
                combine_groups=False,
                vq_dim=vq_dim,
                time_first=True,
                weight_proj_depth=args.quantizer_depth,
                weight_proj_factor=args.quantizer_factor,
            )

        self.num_updates = 0

        # Follow BERT's random weight initialization (for BART)
        if args.bert_init:
            self.apply(init_bert_params)
        self.args = args
        self.prune_modules(args.modules_filter)
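
    # A minimal, self-contained sketch (not executed here) of the codebook
    # mixing that forward() applies when use_codebook is set; the toy shapes
    # (B=2, T=5, C=8) and the 0.5 codebook_prob are assumptions, and q_x
    # stands in for the quantizer output q["x"]:
    #
    #   import torch
    #   enc = torch.randn(2, 5, 8)                # encoder output, B x T x C
    #   q_x = torch.randn(2, 5, 8)                # quantized codes, B x T x C
    #   idx = torch.randperm(5)[: int(5 * 0.5)]   # sampled time steps
    #   q_w = enc.new_zeros(5)
    #   q_w[idx] = 1.0                            # 1 where codes replace features
    #   mixed = q_w.view(-1, 1) * q_x + (1.0 - q_w).view(-1, 1) * enc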

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # Transformer
        parser.add_argument(
            "--activation-fn",
            type=str,
            choices=utils.get_available_activation_fns(),
            help="activation function to use",
        )
        parser.add_argument(
            "--dropout", type=float, metavar="D", help="dropout probability"
        )
        parser.add_argument(
            "--attention-dropout",
            type=float,
            metavar="D",
            help="dropout probability for attention weights",
        )
        parser.add_argument(
            "--activation-dropout",
            "--relu-dropout",
            type=float,
            metavar="D",
            help="dropout probability after activation in FFN.",
        )
        parser.add_argument(
            "--encoder-embed-dim",
            type=int,
            metavar="N",
            help="encoder embedding dimension",
        )
        parser.add_argument(
            "--encoder-ffn-embed-dim",
            type=int,
            metavar="N",
            help="encoder embedding dimension for FFN",
        )
        parser.add_argument(
            "--encoder-layers", type=int, metavar="N", help="num encoder layers"
        )
        parser.add_argument(
            "--encoder-attention-heads",
            type=int,
            metavar="N",
            help="num encoder attention heads",
        )
        parser.add_argument(
            "--encoder-normalize-before",
            action="store_true",
            help="apply layernorm before each encoder block",
        )
        parser.add_argument(
            "--decoder-normalize-before",
            action="store_true",
            help="apply layernorm before each decoder block",
        )
        parser.add_argument(
            "--decoder-embed-dim",
            type=int,
            metavar="N",
            help="decoder embedding dimension",
        )
        parser.add_argument(
            "--decoder-ffn-embed-dim",
            type=int,
            metavar="N",
            help="decoder embedding dimension for FFN",
        )
        parser.add_argument(
            "--decoder-layers", type=int, metavar="N", help="num decoder layers"
        )
        parser.add_argument(
            "--decoder-attention-heads",
            type=int,
            metavar="N",
            help="num decoder attention heads",
        )
        parser.add_argument(
            "--reduction-factor",
            type=int,
            help="reduction factor for decoder",
        )
        parser.add_argument(
            "--spk-embed-dim",
            type=int,
            help="speaker embedding dimension",
        )
        parser.add_argument(
            "--layernorm-embedding",
            action="store_true",
            help="add layernorm to embedding",
        )
        parser.add_argument(
            "--load-pretrained-encoder-from",
            type=str,
            metavar="STR",
            help="model to take encoder weights from (for initialization)",
        )
        parser.add_argument(
            '--freeze-encoder-updates',
            type=int,
            help='number of steps to freeze encoder before finetune'
        )
        parser.add_argument(
            '--freeze-decoder-updates',
            type=int,
            help='number of steps to freeze decoder before finetune'
        )
        parser.add_argument(
            '--no-freeze-encoder-layer',
            type=str,
            help='which encoder layers not to freeze during finetune'
        )
        parser.add_argument(
            "--share-input-output-embed",
            action="store_true",
            help="share decoder input and output embeddings",
        )
        parser.add_argument(
            "--share-ctc-embed",
            action="store_true",
            help="share ctc embed and decoder embed",
        )
        parser.add_argument(
            "--encoder-sliding-window-attn",
            default=None,
            type=int,
            help="If set to an even number, apply sliding-window attention to the encoder's attn_mask, e.g., 4, 10, or 20",
        )

        # Convolutional subsampler
        parser.add_argument(
            "--encoder-speech-prenet",
            default="conv",
            type=str,
            choices=["conv", "linear"],
            help="The type of encoder speech prenet, e.g., conv or linear."
        )
        parser.add_argument(
            "--conv-kernel-sizes",
            default="5,5",
            type=str,
            help="Kernel sizes of the convolution layers in the encoder speech prenet."
        )
        parser.add_argument(
            "--conv-channels",
            default=1024,
            type=int,
            help="Number of channels of the encoder speech prenet."
        )
        parser.add_argument(
            "--subsample-stride",
            default="2,2",
            type=str,
            help="The subsample stride for conv1dsubsample."
        )
        parser.add_argument(
            "--spk-embed-integration-type",
            type=str,
            choices=["pre", "add"],
            help="speaker embedding integration type"
        )
        parser.add_argument(
            "--dprenet-dropout-rate",
            default=0.5,
            type=float,
            help="The dropout rate of decoder speech prenet."
        )

        ## SE
        parser.add_argument(
            "--se-predict",
            default=None,
            choices=["masking", "target", "delta"],
            help="If set, the decoder takes the source speech as input and predicts the masking/target/delta of the corresponding inputs. "
            + "masking is in [0, 1], target is the predicted output, delta is the difference between inputs and outputs",
        )
        parser.add_argument(
            "--se-decoder-input",
            type=str,
            default="previous_target",
            choices=["previous_target", "source"],
        )

        ## SID
        parser.add_argument(
            "--modules-filter",
            default=None,
            type=str,
            help="Remove unused modules for, e.g., SID.",
        )
        parser.add_argument(
            "--sid-pad-prenet",
            action="store_true",
            help="If set, shrink the text dictionary to just the size needed for the <pad> token.",
        )
        parser.add_argument(
            "--encoder-attn-branch",
            type=str,
            default="identity,full",
            help="encoder attention branch sliding window, e.g., 'identity,0,2,4,full'",
        )
        parser.add_argument(
            "--encoder-block-branch",
            type=str,
            help="average the output of encoder, e.g., '4,5,6'",
        )
        parser.add_argument(
            "--sid-encoder-cls",
            default=None,
            choices=["encoder"],
            help="If set, add cls vector to the encoder input, e.g., constant vector.",
        )
        parser.add_argument(
            "--sid-shuffle-encoder-input",
            action="store_true",
            help="If set, shuffle encoder input in time.",
        )
        parser.add_argument(
            "--sid-decoder-speaker",
            action="store_true",
            help="If set, apply speaker decoder as transformer decoder.",
        )
        parser.add_argument(
            "--sid-decoder-attn-dim",
            default=128,
            type=int,
            help="Attention dimension in attentive statistics pooling of speaker decoder.",
        )
        parser.add_argument(
            "--sid-t5-postnet",
            action="store_true",
            help="If set, apply TextDecoderPostnet as the speaker classification head.",
        )
        parser.add_argument(
            "--sid-embed-dim",
            default=128,
            type=int,
            help="Embedding dimension in speaker postnet for speaker identification if embed postnet.",
        )
        parser.add_argument(
            "--sid-pooling-layer",
            default="decoder",
            type=str,
            choices=["decoder-las", "decoder", "encoder", "encoder-cls", "encoder-speaker"],
            help="Whether the decoder or encoder output is used as the SID pooling layer over the temporal dimension.",
        )
        parser.add_argument(
            "--sid-no-pooling-bn",
            action="store_true",
            help="If set, disable batchnorm in attentive pooling.",
        )
        parser.add_argument(
            "--sid-no-embed-postnet",
            action="store_true",
            help="If set, no layer between decoder output and classification layer.",
        )
        parser.add_argument(
            "--sid-normalize-postnet",
            action="store_true",
            help="If set, normalize input and weight in postnet/classifier.",
        )
        parser.add_argument(
            "--sid-softmax-type",
            default="softmax",
            choices=["softmax", "amsoftmax", "aamsoftmax"],
            help="If using amsoftmax or aamsoftmax, the target should be given.",
        )
        parser.add_argument(
            "--softmax-scale",
            default=1.0,
            type=float,
            help="Scale for AMSoftmax or AAMSoftmax.",
        )
        parser.add_argument(
            "--softmax-margin",
            default=0.0,
            type=float,
            help="Margin for AMSoftmax or AAMSoftmax.",
        )
        parser.add_argument(
            "--softmax-easy-margin",
            action="store_true",
            help="Enable easy margin for AAMSoftmax.",
        )
        parser.add_argument(
            "--encoder-layerdrop",
            type=float,
            metavar="D",
            help="LayerDrop probability for encoder",
        )
        parser.add_argument(
            "--decoder-layerdrop",
            type=float,
            metavar="D",
            help="LayerDrop probability for decoder",
        )

        ## Hubert
        parser.add_argument(
            '--feature-grad-mult',
            type=float,
            help='multiply feature extractor var grads by this'
        )
        parser.add_argument(
            '--logit-temp',
            type=float,
            help='temperature to divide logits by'
        )
        parser.add_argument(
            '--final-dim',
            type=int,
            help="project final representations and targets to this many "
            "dimensions. set to encoder_embed_dim if <= 0"
        )

        # mask
        parser.add_argument(
            '--hubert-mask-length',
            type=int,
            help='mask length'
        )
        parser.add_argument(
            '--mask-prob',
            type=float,
            help='probability of replacing a token with mask'
        )
        parser.add_argument(
            "--mask-selection",
            choices=["static", "uniform", "normal", "poisson"],
            help="how to choose mask length",
        )
        parser.add_argument(
            '--mask-other',
            type=float,
            help="secondary mask argument "
            "(used for more complex distributions), "
            "see help in compute_mask_indices"
        )
        parser.add_argument(
            '--mask-min-space',
            type=int,
            help='min space between spans (if no overlap is enabled)'
        )

        # channel masking
        parser.add_argument(
            '--mask-channel-length',
            type=int,
            help='length of the mask for features (channels)'
        )
        parser.add_argument(
            '--mask-channel-prob',
            type=float,
            help="probability of replacing a feature with 0"
        )
        parser.add_argument(
            "--mask-channel-selection",
            choices=["static", "uniform", "normal", "poisson"],
            help="how to choose mask length for channel masking",
        )
        parser.add_argument(
            '--mask-channel-other',
            type=float,
            help="secondary mask argument "
            "(used for more complex distributions), "
            "see help in compute_mask_indices"
        )
        parser.add_argument(
            '--mask-channel-min-space',
            type=int,
            help='min space between spans (if no overlap is enabled)'
        )

        # abs positional embeddings
        parser.add_argument(
            '--conv-pos',
            type=int,
            help='number of filters for convolutional positional embeddings'
        )
        parser.add_argument(
            '--conv-pos-groups',
            type=int,
            help='number of groups for convolutional positional embedding'
        )

        # codebook related
        parser.add_argument(
            "--use-codebook",
            action="store_true",
            help="whether to use codebook",
        )
        parser.add_argument(
            "--codebook-prob",
            type=float,
            help="probability to use codebook",
        )
        parser.add_argument(
            "--latent-vars",
            type=int,
            help="number of latent variables V in each group of the codebook",
        )
        parser.add_argument(
            "--latent-groups",
            type=int,
            help="number of groups G of latent variables in the codebook",
        )
        parser.add_argument(
            "--latent-dim",
            type=int,
            help="if > 0, uses this dimensionality for latent variables. "
            "otherwise uses final_dim / latent_groups",
        )
        parser.add_argument(
            "--latent-temp",
            type=literal_eval,
            help="temperature for latent variable sampling. "
            "can be tuple of 3 values (start, end, decay)",
        )
        parser.add_argument(
            "--quantizer-depth",
            type=int,
            help="number of quantizer layers",
        )
        parser.add_argument(
            "--quantizer-factor",
            type=int,
            help="dimensionality increase for inner quantizer layers (if depth > 1)",
        )
        parser.add_argument(
            "--get-code-distribution",
            action='store_true',
            help="whether to get the code distribution (for test)",
        )

        # relative pos enc
        parser.add_argument(
            "--relative-position-embedding",
            action='store_true',
            help="whether to use relative position embedding",
        )
        parser.add_argument(
            "--num-buckets",
            type=int,
            default=320,
            help="num of buckets for relative position embedding",
        )
        parser.add_argument(
            "--max-distance",
            type=int,
            default=1280,
            help="max distance for relative position embedding",
        )
        parser.add_argument(
            "--encoder-max-relative-position",
            type=int,
            help="max distance for relative position embedding in encoder",
        )
        parser.add_argument(
            "--decoder-max-relative-position",
            type=int,
            help="max distance for relative position embedding in decoder",
        )

        # hubert feature extractor
        parser.add_argument(
            "--conv-feature-layers",
            type=str,
            help="string describing convolutional feature extraction "
            "layers in form of a python list that contains "
            "[(dim, kernel_size, stride), ...]",
        )
        parser.add_argument(
            "--conv-bias",
            action='store_true',
            help="include bias in conv encoder",
        )
        parser.add_argument(
            "--extractor-mode",
            choices=["default", "layer_norm"],
            help="mode for feature extractor. default has a single group "
            "norm with d groups in the first conv block, whereas layer_norm "
            "has layer norms in every block (meant to use with normalize=True)"
        )

        # others
        parser.add_argument(
            "--bert-init",
            action='store_true',
            help="initialize parameters like BERT",
        )
        parser.add_argument(
            "--unb-enc-layer",
            type=int,
            default=-1,
            help="which layer's output is used as the input of decoder",
        )

    # Encoder, Decoder
    @classmethod
    def build_encoder(cls, args, dictionary=None, embed_tokens=None):
        return TransformerEncoder(args, dictionary, embed_tokens)

    @classmethod
    def build_decoder(cls, args):
        return TransformerDecoder(args)

    # Encoder Prenet
    @classmethod
    def build_text_encoder_prenet(cls, embed_tokens, args):
        return TextEncoderPrenet(embed_tokens, args)

    @classmethod
    def build_speech_encoder_prenet(cls, args):
        return SpeechEncoderPrenet(args)

    # Decoder Prenet
    @classmethod
    def build_text_decoder_prenet(cls, embed_tokens, args):
        return TextDecoderPrenet(embed_tokens, args)

    @classmethod
    def build_speech_decoder_prenet(cls, odim, args):
        return SpeechDecoderPrenet(odim, args)

    # Decoder Postnet
    @classmethod
    def build_text_decoder_postnet(cls, embed_tokens, dictionary, args):
        return TextDecoderPostnet(embed_tokens, dictionary, args)

    @classmethod
    def build_speaker_decoder_postnet(cls, embed_dim, class_num, args):
        return SpeakerDecoderPostnet(embed_dim, class_num, args)

    @classmethod
    def build_speech_decoder_postnet(cls, odim, args):
        return SpeechDecoderPostnet(odim, args)

    @classmethod
    def build_speech_encoder_postnet(cls, dictionaries, args):
        return SpeechEncoderPostnet(dictionaries, args)

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present in older models
        base_architecture(args)

        def build_embedding(dictionary, embed_dim, max_num_embeddings=None):
            num_embeddings = len(dictionary)
            if max_num_embeddings is not None and isinstance(max_num_embeddings, int):
                num_embeddings = min(num_embeddings, max_num_embeddings)
            padding_idx = dictionary.pad()
            return Embedding(num_embeddings, embed_dim, padding_idx)

        if hasattr(args, "sid_pad_prenet") and args.sid_pad_prenet:
            max_num_embeddings = 3  # <pad> at index 2
        else:
            max_num_embeddings = None

        text_decoder_embed_tokens = build_embedding(
            task.dicts["text"], args.decoder_embed_dim, max_num_embeddings
        )

        if args.share_input_output_embed:
            text_encoder_embed_tokens = text_decoder_embed_tokens
        else:
            text_encoder_embed_tokens = build_embedding(
                task.dicts["text"], args.encoder_embed_dim
            )

        speech_odim = args.speech_odim
        if "text" in task.dicts:
            encoder = cls.build_encoder(args, task.dicts["text"], text_encoder_embed_tokens)
        else:
            encoder = cls.build_encoder(args)
        decoder = cls.build_decoder(args)

        text_encoder_prenet = cls.build_text_encoder_prenet(text_encoder_embed_tokens, args)
        speech_encoder_prenet = cls.build_speech_encoder_prenet(args)

        text_decoder_prenet = cls.build_text_decoder_prenet(text_decoder_embed_tokens, args)
        if getattr(args, "sid_pooling_layer", None) == "decoder-las":
            speech_decoder_prenet = cls.build_speech_encoder_prenet(args)
        else:
            speech_decoder_prenet = cls.build_speech_decoder_prenet(speech_odim, args)

        text_decoder_postnet = cls.build_text_decoder_postnet(text_decoder_embed_tokens, task.dicts['text'], args)
        speech_decoder_postnet = cls.build_speech_decoder_postnet(speech_odim, args)

        if getattr(args, "sid_t5_postnet", False):
            speaker_decoder_postnet = None
        else:
            if task.t5_task == "s2c":
                speaker_decoder_postnet = cls.build_speaker_decoder_postnet(args.sid_embed_dim, len(task.dicts['text']), args)
            else:
                speaker_decoder_postnet = None

        if "hubert" in task.dicts:
            speech_encoder_postnet = cls.build_speech_encoder_postnet(task.dicts['hubert'], args)
        else:
            speech_encoder_postnet = None

        return cls(
            args,
            encoder, decoder,
            text_encoder_prenet, speech_encoder_prenet,
            text_decoder_prenet, speech_decoder_prenet,
            text_decoder_postnet, speech_decoder_postnet,
            speaker_decoder_postnet, speech_encoder_postnet,
        )
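
    # Hedged usage sketch: fairseq normally reaches build_model through a
    # task, so task.dicts (and, for s2c, task.t5_task) must be populated
    # first; `args` here is an assumed, fully populated namespace:
    #
    #   from fairseq import tasks
    #   task = tasks.setup_task(args)
    #   model = ArTSTTransformerModel.build_model(args, task)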

    def get_normalized_probs(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        # net_output['encoder_out'] is a (B, T, D) tensor
        lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample)
        lprobs.batch_first = True
        return lprobs

    def get_normalized_probs_for_ctc(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from a net's output."""
        logits = net_output["encoder_out_for_ctc"][0]
        if log_probs:
            return utils.log_softmax(logits.float(), dim=-1)
        else:
            return utils.softmax(logits.float(), dim=-1)

    def get_logits(self, net_output, is_masked=True):
        if is_masked:
            logits_list = net_output["logit_m_list"]
        else:
            logits_list = net_output["logit_u_list"]
        logits_list = [x.float() for x in logits_list if x is not None]
        return logits_list

    def get_targets(self, sample, net_output, is_masked=True):
        if "logit_m_list" in net_output:
            logits_list = self.get_logits(net_output, is_masked)
            targets_list = [
                x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list
            ]
            return targets_list
        else:
            return sample["target"]

    def get_extra_losses(self, net_output):
        extra_losses = []
        names = []
        if "features_pen" in net_output:
            extra_losses.append(net_output["features_pen"])
            names.append("features_pen")
        if "prob_perplexity" in net_output:
            extra_losses.append(
                (net_output["num_vars"] - net_output["prob_perplexity"])
                / net_output["num_vars"]
            )
            names.append("prob_perplexity")
        return extra_losses, names
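
    # Sketch of how a criterion might fold these into the training loss; the
    # 10.0 and 0.1 weights are illustrative assumptions, not values from
    # this file:
    #
    #   extra_losses, names = model.get_extra_losses(net_output)
    #   for extra_loss, name in zip(extra_losses, names):
    #       if name == "features_pen":
    #           loss = loss + 10.0 * extra_loss
    #       elif name == "prob_perplexity":
    #           loss = loss + 0.1 * extra_loss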

    def forward(self, source=None, src_tokens=None, src_lengths=None, prev_output_tokens=None, tgt_lengths=None, spkembs=None, target_list=None, task_name=None, padding_mask=None, only_hubert=False, only_ctc=False, feature_only=False, tgt_enc_layer=None, mask=True):
        """
        The forward method inherited from the base class has a **kwargs
        argument in its input, which is not supported in torchscript. This
        method overwrites the forward method definition without **kwargs.
        """
        assert source is not None or src_tokens is not None
        # padding_mask is not None only when the input is a waveform
        if source is None and padding_mask is None and not feature_only:
            input_type = 'text'
        else:
            input_type = 'speech'

        if prev_output_tokens is not None and len(prev_output_tokens.size()) == 2:
            output_type = 'text'
            codebook_out = {}
        else:
            output_type = 'speech'

        if task_name is not None and task_name == "s2c":
            if target_list is not None and target_list.size(1) == 1 and not getattr(self.args, "sid_t5_postnet", False):
                sid_target = F.one_hot(target_list.squeeze(1), num_classes=self.speaker_decoder_postnet.class_num)
            else:
                sid_target = None
            target_list = None

        # Encoder Prenet
        if input_type == 'text':
            encoder_input, encoder_padding_mask = self.text_encoder_prenet(src_tokens)
        else:
            if target_list is not None:
                encoder_input, encoder_padding_mask = self.speech_encoder_prenet(source, require_feat_pen=True, target_list=target_list, padding_mask=padding_mask, mask=mask)
                encoder_input, features_pen, mask_indices, target_list = encoder_input
            else:
                encoder_input, encoder_padding_mask = self.speech_encoder_prenet(source, padding_mask=padding_mask, mask=self.training)
                # shuffle a batch of inputs of encoder
                if self.training and getattr(self.args, "sid_shuffle_encoder_input", False):
                    shuffle_index = torch.randperm(encoder_padding_mask.size(1), device=encoder_padding_mask.device)
                    encoder_input = torch.index_select(encoder_input, 1, shuffle_index)
                    encoder_padding_mask = torch.index_select(encoder_padding_mask, 1, shuffle_index)
                if getattr(self.args, "sid_encoder_cls", None) == "encoder":
                    prev_output_tokens = torch.zeros_like(prev_output_tokens)
                    encoder_input, encoder_padding_mask = self._integrate_with_speaker_cls(prev_output_tokens, encoder_input, encoder_padding_mask)

        # Encoder: T x B x C
        encoder_output = self.encoder(encoder_input, encoder_padding_mask, tgt_layer=tgt_enc_layer)

        if task_name is not None and task_name == 'speech_pretrain' and feature_only:
            return encoder_output["encoder_out"][0].transpose(0, 1)

        if task_name is not None and task_name == 's2c':
            if self.args.sid_pooling_layer == "encoder":
                return self.speaker_decoder_postnet(encoder_output["encoder_out"][0].transpose(0, 1).mean(1), sid_target), None
            elif self.args.sid_pooling_layer == "encoder-cls":
                return self.speaker_decoder_postnet(encoder_output["encoder_out"][0].transpose(0, 1)[:,0], sid_target), None
            elif self.args.sid_pooling_layer == "encoder-speaker" or getattr(self.args, "sid_decoder_speaker", False):
                return self.speaker_decoder_postnet(encoder_output["encoder_out"][0].transpose(0, 1), sid_target), None

        if target_list is not None:
            hubert_results = self.hubert_layer(
                encoder_output["encoder_out"][0].transpose(0, 1),
                encoder_padding_mask,
                mask_indices,
                target_list
            )
            hubert_results['features_pen'] = features_pen

        if "decoder_input" in encoder_output and encoder_output["decoder_input"][0] is not None:
            # Change the encoder output to decoder input once unb-enc-layer is set
            encoder_output["encoder_out"] = encoder_output["decoder_input"]

        if self.use_codebook:
            q = self.quantizer(encoder_output["encoder_out"][0].transpose(0, 1))

            # q["x"]: B x T x C
            # Sample indices according to the codebook prob
            random_idx = torch.randperm(q["x"].size(1))[:int(q["x"].size(1) * self.codebook_prob)]
            # Make weight for q
            q_w = q["x"].new_zeros(q["x"].size(1))
            q_w[random_idx] = 1.0
            # Combine quantized codes and encoder output
            encoder_output["encoder_out"][0] = (
                q_w.view(-1, 1) * q["x"] + (- q_w + 1).view(-1, 1) * encoder_output["encoder_out"][0].transpose(0, 1)
            ).transpose(0, 1)

            # encoder_output["encoder_out"][0] = q["x"].transpose(0, 1)

            if output_type == 'speech':
                hubert_results["prob_perplexity"] = q["prob_perplexity"]
                hubert_results["code_perplexity"] = q["code_perplexity"]
                hubert_results["num_vars"] = q["num_vars"]
                hubert_results["temp"] = q["temp"]
            elif output_type == 'text':
                codebook_out["prob_perplexity"] = q["prob_perplexity"]
                codebook_out["code_perplexity"] = q["code_perplexity"]
                codebook_out["num_vars"] = q["num_vars"]
                codebook_out["temp"] = q["temp"]

        if only_hubert and target_list is not None:
            return hubert_results, None

        if only_ctc and task_name is not None and task_name == "s2t":
            return None, encoder_output
        elif not self.training and prev_output_tokens is None and task_name == "s2t":
            return encoder_output

        # Decoder Prenet
        if output_type == 'text':
            # _ is the incremental state
            prev_output_tokens, tgt_mask, _ = self.text_decoder_prenet(prev_output_tokens)
            if task_name is not None and task_name == 's2c':
                prev_output_tokens = torch.zeros_like(prev_output_tokens)
        else:
            # integrate speaker embedding
            if self.spk_embed_integration_type == "pre" and self.spk_embed_dim is not None:
                # Decoder Prenet
                prev_output_tokens, tgt_mask = self.speech_decoder_prenet(prev_output_tokens, tgt_lengths, spkembs)
            else:
                if self.spk_embed_dim is not None:
                    encoder_output["encoder_out"] = [self._integrate_with_spk_embed(
                        encoder_output["encoder_out"][0].transpose(0, 1), spkembs
                    ).transpose(0, 1)]

                prev_output_tokens, tgt_mask = self.speech_decoder_prenet(prev_output_tokens, tgt_lengths)

        # BART Sequence Classification: cat <pad> + feature before decoder
        if task_name is not None and task_name == 's2c' and self.args.sid_pooling_layer == "decoder-las":
            decoder_feat_input, decoder_feat_mask = self.speech_decoder_prenet(src_tokens, src_lengths)
            prev_output_tokens, tgt_mask = self._integrate_with_speaker_cls((prev_output_tokens, tgt_mask), decoder_feat_input, decoder_feat_mask, cls_first=False)

        # SE: when predicting the masking/target/delta of the inputs, the source speech replaces prev_output_tokens as the decoder input
        if task_name is not None and task_name == "s2s" and getattr(self.args, "se_decoder_input", "previous_target") == "source":
            prev_output_tokens, tgt_mask = self.speech_decoder_prenet(src_tokens, src_lengths)

        # Decoder
        decoder_output, extra = self.decoder(prev_output_tokens, tgt_mask, encoder_output,
                                             full_context_alignment=getattr(self.args, "decoder_full_context_alignment", False),
                                             alignment_layer=(-1 if target_list is None and output_type == 'speech' else None))

        # Decoder Postnet
        if task_name is not None and task_name == 's2c':
            if not getattr(self.args, "sid_t5_postnet", False):
                if self.args.sid_pooling_layer == "decoder":
                    return self.speaker_decoder_postnet(decoder_output.mean(1), sid_target), None
                elif self.args.sid_pooling_layer == "decoder-las":
                    indices = (tgt_mask.eq(False).float().sum(1) - 1.0).type(torch.int64)
                    indices = indices.unsqueeze(1).unsqueeze(2).expand(-1, -1, decoder_output.size(2))
                    return self.speaker_decoder_postnet(decoder_output.gather(1, indices), sid_target), None
            else:
                return (self.text_decoder_postnet(decoder_output), None), encoder_output

        # SE predict: masking, target, delta. Ensure reduction factor 1
        if task_name is not None and task_name == 's2s' and getattr(self.args, "se_predict", None) is not None:
            assert self.reduction_factor == 1, f"{self.reduction_factor} != 1"
            before_outs, after_outs, logits = self.speech_decoder_postnet(decoder_output)
            se_predict = getattr(self.args, "se_predict")
            if se_predict == "masking":
                before_outs = torch.sigmoid(before_outs) * src_tokens
                after_outs = torch.sigmoid(after_outs) * src_tokens
                return before_outs, after_outs, logits, extra['attn'][0]
            elif se_predict == "target":
                return before_outs, after_outs, logits, extra['attn'][0]
            elif se_predict == "delta":
                before_outs = before_outs - src_tokens
                after_outs = after_outs - src_tokens
                return before_outs, after_outs, logits, extra['attn'][0]
            else:
                raise ValueError(f"{se_predict} not in [masking, target, delta]")

        if task_name is not None and task_name == 's2t':
            # return self.text_decoder_postnet(decoder_output), None
            return (self.text_decoder_postnet(decoder_output), None), encoder_output
        if output_type == 'text':
            return (self.text_decoder_postnet(decoder_output), None), codebook_out, encoder_output
        else:
            if target_list is not None:
                return hubert_results, (self.speech_decoder_postnet(decoder_output) + (extra['attn'][0],))
            else:
                return self.speech_decoder_postnet(decoder_output) + (extra['attn'][0],)
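
    # Minimal sketch of a speech-to-text (s2t) call into forward(); the
    # shapes (a batch of two 1 s waveforms, three previous target tokens)
    # are assumptions for illustration:
    #
    #   wav = torch.randn(2, 16000)                    # raw audio, B x T
    #   pad = torch.zeros(2, 16000, dtype=torch.bool)  # True at padded positions
    #   prev = torch.full((2, 3), 1, dtype=torch.long)
    #   (logits, _), encoder_out = model(
    #       source=wav, padding_mask=pad,
    #       prev_output_tokens=prev, task_name="s2t",
    #   )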

    def _integrate_with_speaker_cls(self, pad_input, encoder_input, encoder_padding_mask=None, cls_first=True):
        """
        encoder_input: [B, T, C]
        encoder_padding_mask: [B, T]
        """
        if hasattr(self, "text_decoder_prenet"):
            if isinstance(pad_input, tuple):
                repeat_cls_vector, repeat_cls_mask = pad_input
            else:
                repeat_cls_vector, repeat_cls_mask, _ = self.text_decoder_prenet(pad_input)

        if encoder_padding_mask is not None:
            bsz = encoder_input.size(0)
            tsz = encoder_input.size(1)
            encoder_padding_mask = encoder_input.new_zeros((bsz, tsz)) == 1.0
        if repeat_cls_mask is None:
            mask_size = (encoder_padding_mask.size(0), 1)
            mask_type = encoder_padding_mask.dtype
            repeat_cls_mask = encoder_padding_mask.new_zeros(mask_size) == 1.0
        ret_encoder_padding_mask = torch.cat([repeat_cls_mask, encoder_padding_mask], dim=1)

        if cls_first:
            ret_encoder_input = torch.cat([repeat_cls_vector, encoder_input], dim=1)
        else:
            ret_encoder_input = torch.cat([encoder_input, encoder_input[:,-1:,:]], dim=1)
            mask_size = (encoder_padding_mask.size(0), 1)
            mask_type = encoder_padding_mask.dtype
            repeat_cls_mask_ = encoder_padding_mask.new_ones(mask_size) == 1.0
            encoder_padding_mask_ = torch.cat([encoder_padding_mask, repeat_cls_mask_], dim=1)
            indices = encoder_padding_mask.eq(False).float().sum(1).type(torch.int64).unsqueeze(1)
            indices_mask = torch.zeros_like(ret_encoder_padding_mask).scatter(1, indices, 1.0)
            ret_encoder_input = ret_encoder_input * (1.0 - encoder_padding_mask_.type(ret_encoder_input.dtype).unsqueeze(2)) \
                + repeat_cls_vector * indices_mask.type(repeat_cls_vector.dtype).unsqueeze(2)

        return ret_encoder_input, ret_encoder_padding_mask

    def _integrate_with_spk_embed(self, hs, spembs):
        """Integrate speaker embedding with hidden states.

        Args:
            hs (Tensor): Batch of hidden state sequences (B, Tmax, adim).
            spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim).

        Returns:
            Tensor: Batch of integrated hidden state sequences (B, Tmax, adim)
        """
        if self.spk_embed_integration_type == "add":
            # apply projection and then add to hidden states
            spembs = self.projection(F.normalize(spembs))
            hs = hs + spembs.unsqueeze(1)
        elif self.spk_embed_integration_type == "concat":
            # concat hidden states with spk embeds and then apply projection
            spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
            hs = self.projection(torch.cat([hs, spembs], dim=-1))
        else:
            raise NotImplementedError("support only add or concat.")

        return hs
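
    # Worked shape example, assuming adim=768 and spk_embed_dim=512:
    #   "add":    spembs (B, 512) -> projection -> (B, 768), broadcast-added
    #             over time, so hs stays (B, Tmax, 768);
    #   "concat": spembs normalized and expanded to (B, Tmax, 512); the
    #             concat gives (B, Tmax, 1280) and projection maps it back
    #             to (B, Tmax, 768).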

    def load_state_dict(
        self,
        state_dict,
        strict=True,
        model_cfg=None,
        args=None,
    ):
        """Copies parameters and buffers from *state_dict* into this module
        and its descendants, non-strictly.

        Overrides the method in :class:`nn.Module`. Compared with that method
        this additionally "upgrades" *state_dicts* from old checkpoints.
        """
        # self.prune_modules(model_cfg.modules_filter)
        model_dict_size = self.text_decoder_postnet.output_projection.out_features
        ckpt_dict_size = state_dict["text_decoder_postnet.output_projection.weight"].size(0)
        if model_dict_size != ckpt_dict_size:
            # reset dictionary-related modules, such as the embedding table and encoder ctc embed
            logger.warning(f"not equal dictionary between model and checkpoint: {model_dict_size} vs {ckpt_dict_size}")
            logger.info(f"reset model dictionary with size of {model_dict_size}")
            removed_keys = [
                key for key in state_dict.keys() if any(
                    key.startswith(previ) for previ in [
                        "encoder.proj", "text_encoder_prenet", "text_decoder_prenet", "text_decoder_postnet"
                    ]
                )
            ]
            for key in removed_keys:
                state_dict.pop(key, None)
                logger.info(f"removed {key} from loaded checkpoint")
        for m in self._modules.keys():
            m_state_dict = {
                key.replace(f"{m}.", ""): value for key, value in state_dict.items() if key.startswith(f"{m}.")
            }
            if hasattr(self, m):
                self._modules[m].load_state_dict(m_state_dict, False)
        return self

    def prune_modules(self, modules_filter=None):
        """Prune unused modules for specific tasks."""
        if modules_filter is None:
            return
        elif modules_filter == "s2c":
            if hasattr(self, "text_encoder_prenet"): del self.text_encoder_prenet
            if hasattr(self, "speech_decoder_prenet") and getattr(self.args, "sid_pooling_layer", None) != "decoder-las":
                del self.speech_decoder_prenet
            if hasattr(self, "speech_decoder_postnet"): del self.speech_decoder_postnet
            if hasattr(self, "text_decoder_postnet"): del self.text_decoder_postnet
            if hasattr(self, "speech_encoder_postnet"): del self.speech_encoder_postnet
            if hasattr(self.encoder, "proj"): self.encoder.proj = None
            if hasattr(self, "projection"): del self.projection
            if hasattr(self, "quantizer"): del self.quantizer
            if getattr(self.args, "sid_pooling_layer", "decoder").startswith("encoder") or getattr(self.args, "sid_decoder_speaker", False):
                if hasattr(self.decoder, "dropout_module"): del self.decoder.dropout_module
                if hasattr(self.decoder, "layers"): del self.decoder.layers
                if hasattr(self.decoder, "layer_norm"): del self.decoder.layer_norm
                if hasattr(self, "text_decoder_prenet"): del self.text_decoder_prenet
        elif modules_filter == "s2s":
            if hasattr(self, "speaker_decoder_postnet"): del self.speaker_decoder_postnet
            if hasattr(self, "text_encoder_prenet"): del self.text_encoder_prenet
            if hasattr(self, "text_decoder_prenet"): del self.text_decoder_prenet
            if hasattr(self, "text_decoder_postnet"): del self.text_decoder_postnet
            if hasattr(self, "speech_encoder_postnet"): del self.speech_encoder_postnet
            if hasattr(self.encoder, "proj"): self.encoder.proj = None
            if hasattr(self, "projection"): del self.projection
            if hasattr(self, "quantizer"): del self.quantizer
        elif modules_filter == "t2s":
            if hasattr(self, "speaker_decoder_postnet"): del self.speaker_decoder_postnet
            if hasattr(self, "speech_encoder_prenet"): del self.speech_encoder_prenet
            if hasattr(self, "text_decoder_prenet"): del self.text_decoder_prenet
            if hasattr(self, "text_decoder_postnet"): del self.text_decoder_postnet
            if hasattr(self, "speech_encoder_postnet"): del self.speech_encoder_postnet
            if hasattr(self.encoder, "proj"): self.encoder.proj = None
            if hasattr(self, "projection"): del self.projection
            if hasattr(self, "quantizer"): del self.quantizer
        elif modules_filter == "s3prl":
            # retain only the encoder and the pre/post nets
            if hasattr(self.decoder, "dropout_module"): del self.decoder.dropout_module
            if hasattr(self.decoder, "layers"): del self.decoder.layers
            if hasattr(self.decoder, "layer_norm"): del self.decoder.layer_norm
            if hasattr(self, "speaker_decoder_postnet"): del self.speaker_decoder_postnet
            if hasattr(self, "text_decoder_prenet"): del self.text_decoder_prenet
            if hasattr(self, "text_decoder_postnet"): del self.text_decoder_postnet
            if hasattr(self, "speech_decoder_prenet"): del self.speech_decoder_prenet
            if hasattr(self, "speech_decoder_postnet"): del self.speech_decoder_postnet
            if hasattr(self, "speech_encoder_postnet"): del self.speech_encoder_postnet
            if hasattr(self.encoder, "proj"): self.encoder.proj = None
            if hasattr(self, "projection"): del self.projection
            if hasattr(self, "quantizer"): del self.quantizer

    def forward_encoder_torchscript(self, net_input: Dict[str, Tensor]):
        """A TorchScript-compatible version of forward.

        Encoders which use additional arguments may want to override
        this method for TorchScript compatibility.
        """
        if torch.jit.is_scripting():
            return self.forward_encoder(
                source=net_input["source"],
                padding_mask=net_input["padding_mask"]
            )
        else:
            return self.forward_encoder_non_torchscript(net_input)

    def forward_encoder_non_torchscript(self, net_input: Dict[str, Tensor]):
        encoder_input = {
            k: v for k, v in net_input.items() if k != "prev_output_tokens" and k != "task_name"
        }
        return self.forward_encoder(**encoder_input)

    def forward_encoder(self, source, padding_mask=None):
        # Encoder Prenet
        encoder_input, encoder_padding_mask = self.speech_encoder_prenet(source, padding_mask=padding_mask, mask=False)

        # Encoder
        encoder_output = self.encoder(encoder_input, encoder_padding_mask)

        return encoder_output

    def forward_text_encoder(self, src_tokens):
        # Text Encoder Prenet
        encoder_input, encoder_padding_mask = self.text_encoder_prenet(src_tokens)

        # Encoder
        encoder_output = self.encoder(encoder_input, encoder_padding_mask)

        return encoder_output

    def forward_decoder(self, tokens, encoder_out, incremental_state):
        # Decoder Prenet
        prev_output_tokens, tgt_mask, incremental_state = self.text_decoder_prenet(tokens, incremental_state)

        # Decoder
        decoder_output, extra = self.decoder(
            prev_output_tokens,
            tgt_mask,
            encoder_out=encoder_out,
            incremental_state=incremental_state,
        )

        # Decoder Postnet
        return self.text_decoder_postnet(decoder_output), extra
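
    # Sketch of step-by-step decoding with forward_decoder; the bos_idx and
    # the greedy loop are assumptions for illustration (fairseq's
    # SequenceGenerator normally drives this):
    #
    #   incremental_state = {}
    #   tokens = torch.full((1, 1), bos_idx, dtype=torch.long)
    #   for _ in range(max_len):
    #       logits, extra = model.forward_decoder(tokens, encoder_out, incremental_state)
    #       next_tok = logits[:, -1].argmax(dim=-1, keepdim=True)
    #       tokens = torch.cat([tokens, next_tok], dim=1)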

    def set_num_updates(self, num_updates):
        """Set the number of parameter updates."""
        super().set_num_updates(num_updates)
        self.num_updates = num_updates

    def generate_class(self, source, prev_output_tokens, **kwargs):
        encoder_out = self.forward_encoder(source, padding_mask=kwargs["padding_mask"])

        prev_output_tokens, tgt_mask, _ = self.text_decoder_prenet(prev_output_tokens, {})
        prev_output_tokens = torch.zeros_like(prev_output_tokens)  # s2c uses a zero vector as [CLS]

        decoder_output, extra = self.decoder(
            prev_output_tokens,
            tgt_mask,
            encoder_out=encoder_out,
        )

        decoder_out, embed = self.speaker_decoder_postnet(decoder_output.mean(1))

        pred_class = decoder_out.argmax(1)
        return pred_class

    def generate_speech(self, source=None, src_tokens=None, spkembs=None, **kwargs):
        assert source is not None or src_tokens is not None

        threshold = kwargs.get("threshold", 0.5)
        minlenratio = kwargs.get("minlenratio", 0.0)

        if source is None:
            assert src_tokens.size(0) == 1
            encoder_out = self.forward_text_encoder(src_tokens)
            maxlenratio = kwargs.get("maxlenratio", 20.0)
        else:
            assert source.size(0) == 1
            encoder_out = self.forward_encoder(source, padding_mask=kwargs["padding_mask"])
            maxlenratio = kwargs.get("maxlenratio", 10.0)

        if spkembs is not None and self.spk_embed_integration_type != "pre":
            encoder_out["encoder_out"] = [self._integrate_with_spk_embed(
                encoder_out["encoder_out"][0].transpose(0, 1), spkembs
            ).transpose(0, 1)]
            spkembs = None

        maxlen = int(encoder_out["encoder_out"][0].size(0) * maxlenratio / self.reduction_factor)
        minlen = int(encoder_out["encoder_out"][0].size(0) * minlenratio / self.reduction_factor)

        idx = 0
        ys = encoder_out["encoder_out"][0].new_zeros(1, 1, self.speech_decoder_postnet.odim)
        outs, probs = [], []

        # forward decoder step-by-step
        if isinstance(self.decoder, FairseqIncrementalDecoder):
            incremental_states = {}
        else:
            incremental_states = None
        attns = []
        while True:
            # update index
            idx += 1
            # calculate output and stop prob at idx-th step
            decoder_in, _ = self.speech_decoder_prenet(ys, spkembs=spkembs)
            z, extra = self.decoder(decoder_in[:,-1:], None, encoder_out, incremental_states, alignment_layer=-1)
            outs += [self.speech_decoder_postnet.feat_out(z[0, -1]).view(self.reduction_factor, self.speech_decoder_postnet.odim)]  # [(r, odim), ...]
            probs += [torch.sigmoid(self.speech_decoder_postnet.prob_out(z[0, -1]))]  # [(r), ...]

            # update next inputs
            ys = torch.cat((ys, outs[-1][-1].view(1, 1, self.speech_decoder_postnet.odim)), dim=1)  # (1, idx + 1, odim)
            attns.append(torch.stack([att_l[0] for att_l in extra['attn'][0]], dim=0))
            # check whether to finish generation
            if int(sum(probs[-1] >= threshold)) > 0 or idx >= maxlen:
                # check minimum length
                if idx < minlen:
                    continue
                outs = (torch.cat(outs, dim=0).unsqueeze(0).transpose(1, 2))  # (L, odim) -> (1, L, odim) -> (1, odim, L)
                if self.speech_decoder_postnet.postnet is not None:
                    outs = outs + self.speech_decoder_postnet.postnet(outs)  # (1, odim, L)
                outs = outs.transpose(2, 1).squeeze(0)  # (L, odim)
                probs = torch.cat(probs, dim=0)
                attn = torch.cat(attns, dim=2)
                break

        if outs.size(0) == maxlen:
            logging.warning("output length reaches maximum length")
        return outs, probs, attn
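
    # Hedged usage sketch for text-to-speech generation; loading the model
    # and building src_tokens/spkembs are assumed to happen elsewhere, and
    # src_tokens must hold a single utterance (batch size 1):
    #
    #   model.eval()
    #   with torch.no_grad():
    #       outs, probs, attn = model.generate_speech(
    #           src_tokens=src_tokens,    # (1, T_text) token ids
    #           spkembs=spkembs,          # (1, spk_embed_dim) speaker x-vector
    #           threshold=0.5, minlenratio=0.0, maxlenratio=20.0,
    #       )
    #   # outs: (L, odim) spectrogram frames for a vocoder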


@register_model_architecture(model_name="artst_transformer", arch_name="artst_transformer")
def base_architecture(args):
    # Transformer
    args.bert_init = getattr(args, "bert_init", False)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 768 * 4)
    args.encoder_layers = getattr(args, "encoder_layers", 12)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(
        args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
    )
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 12)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", args.dropout)
    args.activation_dropout = getattr(args, "activation_dropout", args.dropout)
    args.activation_fn = getattr(args, "activation_fn", "gelu")
    args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
    args.decoder_output_dim = getattr(
        args, "decoder_output_dim", args.decoder_embed_dim
    )
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
    args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
    args.max_text_positions = getattr(args, "max_text_positions", DEFAULT_MAX_TEXT_POSITIONS)
    args.max_speech_positions = getattr(args, "max_speech_positions", DEFAULT_MAX_SPEECH_POSITIONS)

    # Espnet related, including prenet, postnet
    args.eprenet_conv_layers = getattr(args, "eprenet_conv_layers", 0)
    args.eprenet_conv_filts = getattr(args, "eprenet_conv_filts", 0)
    args.eprenet_conv_chans = getattr(args, "eprenet_conv_chans", 0)
    args.use_batch_norm = getattr(args, "use_batch_norm", True)
    args.eprenet_dropout_rate = getattr(args, "eprenet_dropout_rate", 0.0)
    args.enc_use_scaled_pos_enc = getattr(args, "enc_use_scaled_pos_enc", True)
    args.dec_use_scaled_pos_enc = getattr(args, "dec_use_scaled_pos_enc", True)
    args.postnet_layers = getattr(args, "postnet_layers", 5)
    args.postnet_chans = getattr(args, "postnet_chans", 256)
    args.postnet_filts = getattr(args, "postnet_filts", 5)
    args.postnet_dropout_rate = getattr(args, "postnet_dropout_rate", 0.5)
    args.dprenet_dropout_rate = getattr(args, "dprenet_dropout_rate", 0.5)
    args.dprenet_layers = getattr(args, "dprenet_layers", 2)
    args.dprenet_units = getattr(args, "dprenet_units", 256)
    args.initial_encoder_alpha = getattr(args, "initial_encoder_alpha", 1.0)
    args.initial_decoder_alpha = getattr(args, "initial_decoder_alpha", 1.0)
    args.spk_embed_integration_type = getattr(args, "spk_embed_integration_type", "pre")
    args.spk_embed_dim = getattr(args, "spk_embed_dim", 512)
    args.encoder_reduction_factor = getattr(args, "encoder_reduction_factor", 1)
    args.reduction_factor = getattr(args, "reduction_factor", 2)
    args.transformer_enc_positional_dropout_rate = getattr(args, "transformer_enc_positional_dropout_rate", 0.1)
    args.transformer_dec_positional_dropout_rate = getattr(args, "transformer_dec_positional_dropout_rate", 0.1)
    args.layer_norm_eps = getattr(args, "layer_norm_eps", 1e-5)
    args.no_scale_embedding = getattr(args, "no_scale_embedding", True)

    # Convolutional subsampler
    args.encoder_speech_prenet = getattr(args, "encoder_speech_prenet", "conv")
    args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5")
    args.conv_channels = getattr(args, "conv_channels", 1024)

    args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
    args.share_input_output_embed = getattr(args, "share_input_output_embed", False)
    args.share_ctc_embed = getattr(args, "share_ctc_embed", False)
    args.freeze_encoder_updates = getattr(args, "freeze_encoder_updates", 0)
    args.freeze_decoder_updates = getattr(args, "freeze_decoder_updates", 0)
    args.no_freeze_encoder_layer = getattr(args, "no_freeze_encoder_layer", None)

    ## sid
    args.sid_embed_dim = getattr(args, "sid_embed_dim", 128)
    args.sid_pooling_layer = getattr(args, "sid_pooling_layer", "decoder")
    args.softmax_scale = getattr(args, "softmax_scale", 1)
    args.softmax_margin = getattr(args, "softmax_margin", 0)
    args.softmax_easy_margin = getattr(args, "softmax_easy_margin", False)
    args.modules_filter = getattr(args, "modules_filter", None)

    ## Hubert
    args.conv_pos = getattr(args, "conv_pos", 128)
    args.conv_pos_groups = getattr(args, "conv_pos_groups", 16)
    args.target_glu = getattr(args, "target_glu", False)
    args.logit_temp = getattr(args, "logit_temp", 0.1)
    args.final_dim = getattr(args, "final_dim", 256)
    args.untie_final_proj = getattr(args, "untie_final_proj", True)
    args.feature_grad_mult = getattr(args, "feature_grad_mult", 0.1)
    args.use_sent_enc_layer = getattr(args, "use_sent_enc_layer", True)

    # hubert feature extractor
    args.extractor_mode = getattr(args, "extractor_mode", "default")
    args.conv_feature_layers = getattr(args, "conv_feature_layers", "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2")
    args.conv_bias = getattr(args, "conv_bias", False)

    # mask
    args.hubert_mask_length = getattr(args, "hubert_mask_length", 10)
    args.mask_prob = getattr(args, "mask_prob", 0.0)
    args.mask_selection = getattr(args, "mask_selection", "static")
    args.mask_other = getattr(args, "mask_other", 0)
    args.no_mask_overlap = getattr(args, "no_mask_overlap", False)
    args.mask_min_space = getattr(args, "mask_min_space", 1)

    # channel mask
    args.mask_channel_length = getattr(args, "mask_channel_length", 10)
    args.mask_channel_prob = getattr(args, "mask_channel_prob", 0.0)
    args.mask_channel_selection = getattr(args, "mask_channel_selection", "static")
    args.mask_channel_other = getattr(args, "mask_channel_other", 0)
    args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap", False)
    args.mask_channel_min_space = getattr(args, "mask_channel_min_space", 1)

    # loss computation
    args.skip_masked = getattr(args, "skip_masked", False)
    args.skip_nomask = getattr(args, "skip_nomask", False)

    # conv pos
    args.use_conv_pos = getattr(args, "use_conv_pos", False)
    args.use_sinc_pos = getattr(args, "use_sinc_pos", False)

    # codebook
    args.use_codebook = getattr(args, "use_codebook", False)
    args.latent_vars = getattr(args, "latent_vars", 100)
    args.latent_groups = getattr(args, "latent_groups", 2)
    args.latent_dim = getattr(args, "latent_dim", 0)
    args.latent_temp = getattr(args, "latent_temp", (2, 0.5, 0.999995))
    args.quantizer_depth = getattr(args, "quantizer_depth", 1)
    args.quantizer_factor = getattr(args, "quantizer_factor", 3)
    args.codebook_prob = getattr(args, "codebook_prob", 0.5)

    # Relative pos embed
    args.relative_position_embedding = getattr(args, "relative_position_embedding", False)
    args.num_buckets = getattr(args, "num_buckets", 320)
    args.max_distance = getattr(args, "max_distance", 1280)
    args.encoder_max_relative_position = getattr(args, "encoder_max_relative_position", 160)
    args.decoder_max_relative_position = getattr(args, "decoder_max_relative_position", 160)
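
# Illustrative (hedged) CLI usage: with this module importable via fairseq's
# --user-dir, the architectures registered here can be selected by name; the
# data path, task, and user-dir value below are assumptions:
#
#   fairseq-train ${DATA_ROOT} --user-dir artst --arch artst_transformer_base ...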


@register_model_architecture("artst_transformer", "artst_transformer_base")
def artst_transformer_base(args):
    args.use_conv_pos = getattr(args, "use_conv_pos", True)
    args.use_sinc_pos = getattr(args, "use_sinc_pos", True)
    args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
    args.layer_norm_first = getattr(args, "layer_norm_first", False)
    args.relative_position_embedding = getattr(args, "relative_position_embedding", True)
    args.dropout = getattr(args, "dropout", 0.1)
    args.activation_dropout = getattr(args, "activation_dropout", 0.0)
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.05)
    args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.05)
    args.mask_prob = getattr(args, "mask_prob", 0.80)
    base_architecture(args)


@register_model_architecture("artst_transformer", "artst_transformer_large")
def artst_transformer_large(args):
    args.use_conv_pos = getattr(args, "use_conv_pos", True)
    args.use_sinc_pos = getattr(args, "use_sinc_pos", True)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
    args.layer_norm_first = getattr(args, "layer_norm_first", True)
    args.relative_position_embedding = getattr(args, "relative_position_embedding", True)
    args.dropout = getattr(args, "dropout", 0.0)
    args.activation_dropout = getattr(args, "activation_dropout", 0.0)
    args.attention_dropout = getattr(args, "attention_dropout", 0.0)
    args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
    args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
    args.encoder_layers = getattr(args, "encoder_layers", 24)
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
    args.feature_grad_mult = getattr(args, "feature_grad_mult", 1.0)
    args.extractor_mode = getattr(args, "extractor_mode", "layer_norm")
    args.final_dim = getattr(args, "final_dim", 768)
    args.mask_prob = getattr(args, "mask_prob", 0.80)
    base_architecture(args)


@register_model_architecture("artst_transformer", "artst_transformer_base_asr")
def artst_transformer_base_asr(args):
    args.use_conv_pos = getattr(args, "use_conv_pos", True)
    args.use_sinc_pos = getattr(args, "use_sinc_pos", True)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
    args.layer_norm_first = getattr(args, "layer_norm_first", False)
    args.relative_position_embedding = getattr(args, "relative_position_embedding", True)
    args.dropout = getattr(args, "dropout", 0.1)
    args.activation_dropout = getattr(args, "activation_dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    args.feature_grad_mult = getattr(args, "feature_grad_mult", 0.0)
    args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.1)
    args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.1)
    args.mask_prob = getattr(args, "mask_prob", 0.75)
    args.mask_selection = getattr(args, "mask_selection", "static")
    args.mask_channel_length = getattr(args, "mask_channel_length", 64)
    args.mask_channel_prob = getattr(args, "mask_channel_prob", 0.5)
    args.mask_channel_selection = getattr(args, "mask_channel_selection", "static")
    args.max_text_positions = getattr(args, "max_text_positions", 600)
    base_architecture(args)