# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------
import contextlib

import torch
import torch.nn as nn
import torch.nn.functional as F

from espnet.nets.pytorch_backend.tacotron2.decoder import Prenet as TacotronDecoderPrenet
from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding
from espnet.nets.pytorch_backend.transformer.embedding import ScaledPositionalEncoding
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask

class SpeechDecoderPrenet(nn.Module):
    """Speech decoder prenet.

    Projects target speech frames into the decoder embedding space, adds
    positional encodings, and optionally fuses a speaker embedding.

    Args:
        odim (int): dimension of the input speech features (e.g. number of mel bins).
        args: model configuration holding the prenet hyperparameters
            (dprenet_layers, dprenet_units, dprenet_dropout_rate,
            decoder_embed_dim, dropout rates, spk_embed_* options,
            freeze_decoder_updates, ...).
    """
    def __init__(
        self,
        odim,
        args,
    ):
        super(SpeechDecoderPrenet, self).__init__()
        # define decoder prenet
        if args.dprenet_layers != 0:
            # Tacotron2-style prenet followed by a projection to the decoder dimension
            decoder_input_layer = torch.nn.Sequential(
                TacotronDecoderPrenet(
                    idim=odim,
                    n_layers=args.dprenet_layers,
                    n_units=args.dprenet_units,
                    dropout_rate=args.dprenet_dropout_rate,
                ),
                torch.nn.Linear(args.dprenet_units, args.decoder_embed_dim),
            )
        else:
            decoder_input_layer = "linear"

        pos_enc_class = (
            ScaledPositionalEncoding if args.dec_use_scaled_pos_enc else PositionalEncoding
        )

        if decoder_input_layer == "linear":
            self.decoder_prenet = torch.nn.Sequential(
                torch.nn.Linear(odim, args.decoder_embed_dim),
                torch.nn.LayerNorm(args.decoder_embed_dim),
                torch.nn.Dropout(args.transformer_dec_dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(args.decoder_embed_dim, args.transformer_dec_positional_dropout_rate),
            )
        elif isinstance(decoder_input_layer, torch.nn.Module):
            self.decoder_prenet = torch.nn.Sequential(
                decoder_input_layer,
                pos_enc_class(
                    args.decoder_embed_dim,
                    args.transformer_dec_positional_dropout_rate,
                    max_len=args.max_speech_positions,
                ),
            )

        # 'pre' integration: concatenate the speaker embedding with the prenet
        # output and project back to the decoder embedding dimension.
        if args.spk_embed_integration_type == 'pre':
            self.spkembs_layer = torch.nn.Sequential(
                torch.nn.Linear(args.spk_embed_dim + args.decoder_embed_dim, args.decoder_embed_dim),
                torch.nn.ReLU(),
            )

        self.num_updates = 0
        self.freeze_decoder_updates = args.freeze_decoder_updates

    def forward(self, prev_output_tokens, tgt_lengths_in=None, spkembs=None):
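        """Run the prenet on previously generated speech frames.

        Args:
            prev_output_tokens (Tensor): previous output frames (B, T_dec, odim).
            tgt_lengths_in (LongTensor, optional): target frame lengths (B,).
            spkembs (Tensor, optional): speaker embeddings (B, spk_embed_dim).

        Returns:
            Tensor: prenet outputs (B, T_dec, decoder_embed_dim).
            Tensor or None: padding mask (B, T_dec) with True at padded frames.
        """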
        # Gradients flow only once `freeze_decoder_updates` steps have passed;
        # before that the prenet runs under no_grad (ExitStack() is a no-op context).
        ft = self.freeze_decoder_updates <= self.num_updates
        with torch.no_grad() if not ft else contextlib.ExitStack():
            prev_output_tokens = self.decoder_prenet(prev_output_tokens)

            if spkembs is not None:
                # L2-normalize the speaker embedding, broadcast it over time, and
                # project the concatenation back to the decoder embedding dimension.
                spkembs = F.normalize(spkembs).unsqueeze(1).expand(-1, prev_output_tokens.size(1), -1)
                prev_output_tokens = self.spkembs_layer(torch.cat([prev_output_tokens, spkembs], dim=-1))

            if tgt_lengths_in is not None:
                # True marks padded frames (inverse of the non-pad mask).
                tgt_frames_mask = ~(self._source_mask(tgt_lengths_in).squeeze(1))
            else:
                tgt_frames_mask = None

            return prev_output_tokens, tgt_frames_mask

    def _source_mask(self, ilens):
        """Make masks for self-attention.

        Args:
            ilens (LongTensor or List): Batch of lengths (B,).

        Returns:
            Tensor: Mask tensor for self-attention.
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (including 1.2)

        Examples:
            >>> ilens = [5, 3]
            >>> self._source_mask(ilens)
            tensor([[[1, 1, 1, 1, 1]],
                    [[1, 1, 1, 0, 0]]], dtype=torch.uint8)
        """
        x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
        return x_masks.unsqueeze(-2)

    def set_num_updates(self, num_updates):
        """Set the number of parameter updates (used to unfreeze the prenet)."""
        self.num_updates = num_updates
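

# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# The hyperparameter values below are assumptions chosen merely to exercise the
# prenet; real values come from the fairseq/ArTST model arguments.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        dprenet_layers=2,
        dprenet_units=256,
        dprenet_dropout_rate=0.5,
        decoder_embed_dim=768,
        transformer_dec_dropout_rate=0.1,
        transformer_dec_positional_dropout_rate=0.1,
        dec_use_scaled_pos_enc=True,
        max_speech_positions=4000,
        spk_embed_integration_type="pre",
        spk_embed_dim=512,
        freeze_decoder_updates=0,
    )
    prenet = SpeechDecoderPrenet(odim=80, args=args)

    frames = torch.randn(2, 100, 80)    # (B, T_dec, odim) previous mel frames
    lengths = torch.tensor([100, 60])   # frame counts per utterance
    spk = torch.randn(2, 512)           # speaker embeddings, e.g. x-vectors
    out, pad_mask = prenet(frames, lengths, spk)
    print(out.shape, pad_mask.shape)    # torch.Size([2, 100, 768]) torch.Size([2, 100])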