# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------

import contextlib

import torch
import torch.nn as nn

from fairseq import utils
from fairseq.modules import (
    AdaptiveSoftmax,
)


class TextDecoderPostnet(nn.Module):
    """Projects decoder hidden states back to vocabulary logits.

    Args:
        embed_tokens (nn.Embedding): decoder input embedding table; its weight
            is reused for the output projection when input/output embeddings
            are shared
        dictionary (fairseq.data.Dictionary): output vocabulary
        args: model configuration (uses ``decoder_output_dim``,
            ``share_input_output_embed``, ``adaptive_softmax_cutoff``,
            ``freeze_decoder_updates``, and the related adaptive-softmax options)
        output_projection (nn.Module, optional): externally built projection
            layer; if ``None``, one is constructed here
    """

    def __init__(self, embed_tokens, dictionary, args, output_projection=None):
        super(TextDecoderPostnet, self).__init__()
        self.output_embed_dim = args.decoder_output_dim
        self.output_projection = output_projection
        self.adaptive_softmax = None
        self.share_input_output_embed = args.share_input_output_embed
        if self.output_projection is None:
            self.build_output_projection(args, dictionary, embed_tokens)
        self.freeze_decoder_updates = args.freeze_decoder_updates
        self.num_updates = 0

    def output_layer(self, features):
        """Project features to the vocabulary size."""
        if self.adaptive_softmax is None:
            # project back to size of vocabulary
            return self.output_projection(features)
        else:
            # adaptive softmax applies its own projection inside the criterion
            return features

    def build_output_projection(self, args, dictionary, embed_tokens):
        if args.adaptive_softmax_cutoff is not None:
            self.adaptive_softmax = AdaptiveSoftmax(
                len(dictionary),
                self.output_embed_dim,
                utils.eval_str_list(args.adaptive_softmax_cutoff, type=int),
                dropout=args.adaptive_softmax_dropout,
                adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None,
                factor=args.adaptive_softmax_factor,
                tie_proj=args.tie_adaptive_proj,
            )
        elif self.share_input_output_embed:
            # tie the output projection to the input embedding table
            self.output_projection = nn.Linear(
                embed_tokens.weight.shape[1],
                embed_tokens.weight.shape[0],
                bias=False,
            )
            self.output_projection.weight = embed_tokens.weight
        else:
            self.output_projection = nn.Linear(
                self.output_embed_dim, len(dictionary), bias=False
            )
            nn.init.normal_(
                self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5
            )
        # num_base_layers = getattr(args, "base_layers", 0)
        # for i in range(num_base_layers):
        #     self.layers.insert(
        #         ((i + 1) * args.decoder_layers) // (num_base_layers + 1),
        #         BaseLayer(args),
        #     )

    def forward(self, x):
        # run without gradients while the decoder is still within its frozen-update window
        ft = self.freeze_decoder_updates <= self.num_updates
        with torch.no_grad() if not ft else contextlib.ExitStack():
            return self._forward(x)

    def _forward(self, x):
        # project decoder states to vocabulary logits
        x = self.output_layer(x)
        return x

    def set_num_updates(self, num_updates):
        """Set the number of parameter updates."""
        self.num_updates = num_updates
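

# --------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hypothetical
# example showing how TextDecoderPostnet maps decoder hidden states of shape
# (batch, time, decoder_output_dim) to vocabulary logits. The Namespace below
# stands in for fairseq's model config; the field values are illustrative
# assumptions, not ArTST's actual hyperparameters.
# --------------------------------------------------------
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        decoder_output_dim=768,         # assumed hidden size
        share_input_output_embed=True,  # tie output projection to embeddings
        adaptive_softmax_cutoff=None,   # plain linear projection, no adaptive softmax
        freeze_decoder_updates=0,       # postnet trainable from step 0
    )
    vocab_size = 10000  # hypothetical vocabulary size
    embed_tokens = nn.Embedding(vocab_size, args.decoder_output_dim)
    # a real setup passes a fairseq Dictionary; with tied embeddings only its
    # length would matter, so any sized container works for this sketch
    dictionary = range(vocab_size)

    postnet = TextDecoderPostnet(embed_tokens, dictionary, args)
    hidden = torch.randn(2, 20, args.decoder_output_dim)  # (batch, time, dim)
    logits = postnet(hidden)
    print(logits.shape)  # torch.Size([2, 20, 10000])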