# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------
import contextlib

import torch
import torch.nn as nn
import torch.nn.functional as F

from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask
from espnet.nets.pytorch_backend.tacotron2.decoder import Prenet as TacotronDecoderPrenet
from espnet.nets.pytorch_backend.transformer.embedding import (
    PositionalEncoding,
    ScaledPositionalEncoding,
)


class SpeechDecoderPrenet(nn.Module):
    """Speech decoder prenet.

    Projects target speech frames (e.g. log-Mel spectrogram frames) into the
    decoder embedding space and adds positional encodings, optionally fusing
    a speaker embedding into every frame.

    Args:
        odim (int): dimension of each target speech frame (number of feature bins)
        args: namespace with the prenet hyperparameters (dprenet_layers,
            dprenet_units, decoder_embed_dim, spk_embed_dim, ...)
    """

    def __init__(
        self,
        odim,
        args,
    ):
        super(SpeechDecoderPrenet, self).__init__()

        # Decoder prenet: either a Tacotron2-style prenet followed by a linear
        # projection, or (when dprenet_layers == 0) a plain linear input layer.
        if args.dprenet_layers != 0:
            decoder_input_layer = torch.nn.Sequential(
                TacotronDecoderPrenet(
                    idim=odim,
                    n_layers=args.dprenet_layers,
                    n_units=args.dprenet_units,
                    dropout_rate=args.dprenet_dropout_rate,
                ),
                torch.nn.Linear(args.dprenet_units, args.decoder_embed_dim),
            )
        else:
            decoder_input_layer = "linear"

        pos_enc_class = (
            ScaledPositionalEncoding if args.dec_use_scaled_pos_enc else PositionalEncoding
        )

        if decoder_input_layer == "linear":
            self.decoder_prenet = torch.nn.Sequential(
                torch.nn.Linear(odim, args.decoder_embed_dim),
                torch.nn.LayerNorm(args.decoder_embed_dim),
                torch.nn.Dropout(args.transformer_dec_dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(args.decoder_embed_dim, args.transformer_dec_positional_dropout_rate),
            )
        elif isinstance(decoder_input_layer, torch.nn.Module):
            self.decoder_prenet = torch.nn.Sequential(
                decoder_input_layer,
                pos_enc_class(
                    args.decoder_embed_dim,
                    args.transformer_dec_positional_dropout_rate,
                    max_len=args.max_speech_positions,
                ),
            )

        if args.spk_embed_integration_type == 'pre':
            # Project the concatenation of frame embedding and speaker embedding
            # back to the decoder embedding dimension.
            self.spkembs_layer = torch.nn.Sequential(
                torch.nn.Linear(args.spk_embed_dim + args.decoder_embed_dim, args.decoder_embed_dim),
                torch.nn.ReLU(),
            )

        self.num_updates = 0
        self.freeze_decoder_updates = args.freeze_decoder_updates

    def forward(self, prev_output_tokens, tgt_lengths_in=None, spkembs=None):
        # Run under no_grad while the decoder is still frozen, i.e. before
        # `freeze_decoder_updates` training steps have elapsed.
        ft = self.freeze_decoder_updates <= self.num_updates
        with torch.no_grad() if not ft else contextlib.ExitStack():
            prev_output_tokens = self.decoder_prenet(prev_output_tokens)

            if spkembs is not None:
                # Broadcast the normalized speaker embedding over all frames
                # and fuse it into each frame embedding.
                spkembs = F.normalize(spkembs).unsqueeze(1).expand(-1, prev_output_tokens.size(1), -1)
                prev_output_tokens = self.spkembs_layer(torch.cat([prev_output_tokens, spkembs], dim=-1))

            if tgt_lengths_in is not None:
                # True marks padded frames.
                tgt_frames_mask = ~(self._source_mask(tgt_lengths_in).squeeze(1))
            else:
                tgt_frames_mask = None
            return prev_output_tokens, tgt_frames_mask

    def _source_mask(self, ilens):
        """Make masks for self-attention.

        Args:
            ilens (LongTensor or List): Batch of lengths (B,).

        Returns:
            Tensor: Mask tensor for self-attention.
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (including 1.2)

        Examples:
            >>> ilens = [5, 3]
            >>> self._source_mask(ilens)
            tensor([[[1, 1, 1, 1, 1]],
                    [[1, 1, 1, 0, 0]]], dtype=torch.uint8)
        """
        x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
        return x_masks.unsqueeze(-2)

    def set_num_updates(self, num_updates):
        """Set the number of parameter updates."""
        self.num_updates = num_updates
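

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original module).
# The field values in the namespace below are assumptions chosen for this
# example, not the ArTST training configuration; they only cover the fields
# that this prenet actually reads from `args`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        dprenet_layers=2,                             # assumed; 0 would select the plain "linear" input layer
        dprenet_units=256,                            # assumed Tacotron prenet hidden size
        dprenet_dropout_rate=0.5,                     # assumed
        decoder_embed_dim=768,                        # assumed decoder embedding dimension
        transformer_dec_dropout_rate=0.1,             # assumed
        transformer_dec_positional_dropout_rate=0.1,  # assumed
        dec_use_scaled_pos_enc=True,                  # assumed
        max_speech_positions=4000,                    # assumed maximum number of speech frames
        spk_embed_integration_type="pre",             # enables the speaker-embedding fusion branch
        spk_embed_dim=512,                            # assumed speaker-embedding (e.g. x-vector) size
        freeze_decoder_updates=0,                     # 0 keeps the prenet trainable from the first step
    )

    prenet = SpeechDecoderPrenet(odim=80, args=args)

    feats = torch.randn(2, 10, 80)   # (batch, frames, odim), e.g. log-Mel target frames
    lengths = torch.tensor([10, 7])  # number of valid frames per utterance
    spkembs = torch.randn(2, 512)    # one speaker embedding per utterance

    out, pad_mask = prenet(feats, tgt_lengths_in=lengths, spkembs=spkembs)
    print(out.shape)       # torch.Size([2, 10, 768])
    print(pad_mask.shape)  # torch.Size([2, 10]); True marks padded frames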