# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------

import contextlib
import math

import torch
import torch.nn as nn

from fairseq.models.transformer import Linear
from fairseq.modules import (
    FairseqDropout,
    LayerNorm,
    PositionalEmbedding,
)
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_


class TextDecoderPrenet(nn.Module):
    """Decoder prenet for text targets: embeds the previous output tokens,
    scales them, and adds positional embeddings before they enter the
    transformer decoder.

    Args:
        embed_tokens (nn.Embedding): token embedding table (shared with the decoder)
        args: model configuration, e.g. decoder_embed_dim, dropout,
            max_text_positions, freeze_decoder_updates
    """

    def __init__(self, embed_tokens, args):
        super().__init__()
        self.dropout_module = FairseqDropout(
            args.dropout, module_name=self.__class__.__name__
        )
        self.decoder_layerdrop = args.decoder_layerdrop
        self.num_updates = 0

        input_embed_dim = embed_tokens.embedding_dim
        embed_dim = args.decoder_embed_dim
        self.embed_dim = embed_dim
        self.output_embed_dim = args.decoder_output_dim

        self.padding_idx = embed_tokens.padding_idx
        self.embed_tokens = embed_tokens

        # Scale embeddings by sqrt(d) unless disabled, as in the original transformer.
        self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)

        # Optional quantization noise (Quant-Noise with product quantization) on a
        # linear projection; only active when quant_noise_pq > 0.
        if not args.adaptive_input and args.quant_noise_pq > 0:
            self.quant_noise = apply_quant_noise_(
                nn.Linear(embed_dim, embed_dim, bias=False),
                args.quant_noise_pq,
                args.quant_noise_pq_block_size,
            )
        else:
            self.quant_noise = None

        # Project token embeddings if their dimension differs from the decoder's.
        self.project_in_dim = (
            Linear(input_embed_dim, embed_dim, bias=False)
            if embed_dim != input_embed_dim
            else None
        )

        # Learned or sinusoidal positional embeddings, unless disabled.
        self.embed_positions = (
            PositionalEmbedding(
                args.max_text_positions,
                embed_dim,
                self.padding_idx,
                learned=args.decoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )

        export = getattr(args, "export", False)
        if getattr(args, "layernorm_embedding", False):
            self.layernorm_embedding = LayerNorm(embed_dim, export=export)
        else:
            self.layernorm_embedding = None

        # Number of updates during which the prenet stays frozen (no gradients).
        self.freeze_decoder_updates = args.freeze_decoder_updates

    def forward(self, prev_output_tokens, incremental_state=None):
        # Fine-tune (track gradients) only after freeze_decoder_updates steps;
        # before that, run the forward pass under torch.no_grad().
        ft = self.freeze_decoder_updates <= self.num_updates
        with torch.no_grad() if not ft else contextlib.ExitStack():
            return self._forward(prev_output_tokens, incremental_state)

    def _forward(self, prev_output_tokens, incremental_state=None):
        # Build a padding mask only if padding tokens are actually present.
        pad_mask = prev_output_tokens.eq(self.padding_idx)
        x_mask = pad_mask if pad_mask.any() else None

        # embed positions
        positions = None
        if self.embed_positions is not None:
            positions = self.embed_positions(
                prev_output_tokens, incremental_state=incremental_state
            )

        # During incremental decoding, only the last token is fed in.
        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.quant_noise is not None:
            x = self.quant_noise(x)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions

        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)

        x = self.dropout_module(x)

        return x, x_mask, incremental_state

    def set_num_updates(self, num_updates):
        """Set the number of parameter updates."""
        self.num_updates = num_updates
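

# --------------------------------------------------------
# Usage sketch (not part of the upstream module). A minimal example of how
# TextDecoderPrenet might be instantiated and called, assuming fairseq is
# installed. The Namespace fields mirror the attributes read in __init__;
# the values below are illustrative assumptions, not ArTST's released
# hyperparameters.
# --------------------------------------------------------
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        dropout=0.1,
        decoder_layerdrop=0.0,
        decoder_embed_dim=768,
        decoder_output_dim=768,
        no_scale_embedding=False,
        adaptive_input=False,
        quant_noise_pq=0.0,          # disable quantization noise
        quant_noise_pq_block_size=8,
        max_text_positions=600,
        decoder_learned_pos=False,   # sinusoidal positional embeddings
        no_token_positional_embeddings=False,
        layernorm_embedding=False,
        freeze_decoder_updates=0,    # never freeze the prenet
    )

    vocab_size, padding_idx = 1000, 1
    embed_tokens = nn.Embedding(vocab_size, args.decoder_embed_dim, padding_idx)
    prenet = TextDecoderPrenet(embed_tokens, args)

    # (batch=2, tgt_len=7) batch of previous output tokens, no padding.
    prev_output_tokens = torch.randint(2, vocab_size, (2, 7))
    x, x_mask, _ = prenet(prev_output_tokens)
    print(x.shape)   # torch.Size([2, 7, 768])
    print(x_mask)    # None, since no padding tokens are present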