# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------

import contextlib
import math

import torch
import torch.nn as nn

from fairseq.models.transformer import Linear
from fairseq.modules import (
    FairseqDropout,
    LayerNorm,
    PositionalEmbedding,
)
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_


class TextDecoderPrenet(nn.Module):
    """Text decoder prenet: embeds previous output tokens, adds positional
    information, and applies optional layer norm and dropout before the
    transformer decoder.

    Args:
        embed_tokens (nn.Embedding): token embedding table, shared with the decoder
        args (argparse.Namespace): model configuration (embedding dimensions,
            dropout, positional-embedding and quantization-noise settings)
    """

    def __init__(self, embed_tokens, args):
        super(TextDecoderPrenet, self).__init__()
        self.dropout_module = FairseqDropout(
            args.dropout, module_name=self.__class__.__name__
        )
        self.decoder_layerdrop = args.decoder_layerdrop
        self.num_updates = 0

        input_embed_dim = embed_tokens.embedding_dim
        embed_dim = args.decoder_embed_dim
        self.embed_dim = embed_dim
        self.output_embed_dim = args.decoder_output_dim

        self.padding_idx = embed_tokens.padding_idx
        self.embed_tokens = embed_tokens
        self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)

        if not args.adaptive_input and args.quant_noise_pq > 0:
            self.quant_noise = apply_quant_noise_(
                nn.Linear(embed_dim, embed_dim, bias=False),
                args.quant_noise_pq,
                args.quant_noise_pq_block_size,
            )
        else:
            self.quant_noise = None

        self.project_in_dim = (
            Linear(input_embed_dim, embed_dim, bias=False)
            if embed_dim != input_embed_dim
            else None
        )

        self.embed_positions = (
            PositionalEmbedding(
                args.max_text_positions,
                embed_dim,
                self.padding_idx,
                learned=args.decoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )

        export = getattr(args, "export", False)
        if getattr(args, "layernorm_embedding", False):
            self.layernorm_embedding = LayerNorm(embed_dim, export=export)
        else:
            self.layernorm_embedding = None

        self.freeze_decoder_updates = args.freeze_decoder_updates

    def forward(self, prev_output_tokens, incremental_state=None):
        # Run without gradients until `freeze_decoder_updates` training
        # updates have elapsed, effectively freezing the prenet early on.
        ft = self.freeze_decoder_updates <= self.num_updates
        with torch.no_grad() if not ft else contextlib.ExitStack():
            return self._forward(prev_output_tokens, incremental_state)

    def _forward(self, prev_output_tokens, incremental_state=None):
        # Padding mask over the target tokens (None if nothing is padded)
        if prev_output_tokens.eq(self.padding_idx).any():
            x_mask = prev_output_tokens.eq(self.padding_idx)
        else:
            x_mask = None

        # embed positions
        positions = None
        if self.embed_positions is not None:
            positions = self.embed_positions(
                prev_output_tokens, incremental_state=incremental_state
            )

        # During incremental decoding, only the newest token is embedded
        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.quant_noise is not None:
            x = self.quant_noise(x)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions

        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)

        x = self.dropout_module(x)

        return x, x_mask, incremental_state

    def set_num_updates(self, num_updates):
        """Set the number of parameter updates."""
        self.num_updates = num_updates
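

# --------------------------------------------------------
# Minimal usage sketch (not part of the original ArTST sources): builds a
# toy argument namespace and runs one forward pass. The hyperparameter
# values below are illustrative assumptions, not the values ArTST trains
# with; real runs obtain `args` from fairseq's argument parsing.
# --------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        dropout=0.1,
        decoder_layerdrop=0.0,
        decoder_embed_dim=768,        # assumed embedding width
        decoder_output_dim=768,
        no_scale_embedding=False,     # scale embeddings by sqrt(embed_dim)
        adaptive_input=False,
        quant_noise_pq=0.0,           # disable quantization noise
        quant_noise_pq_block_size=8,
        max_text_positions=450,       # assumed maximum text length
        decoder_learned_pos=False,    # sinusoidal positional embeddings
        no_token_positional_embeddings=False,
        layernorm_embedding=True,
        freeze_decoder_updates=0,     # never freeze the prenet
    )

    vocab_size, padding_idx = 100, 1
    embed_tokens = nn.Embedding(vocab_size, args.decoder_embed_dim, padding_idx)
    prenet = TextDecoderPrenet(embed_tokens, args).eval()  # eval() disables dropout

    # Token ids drawn from [2, vocab_size) so no position is padding.
    prev_output_tokens = torch.randint(2, vocab_size, (2, 10))
    x, x_mask, _ = prenet(prev_output_tokens)
    print(x.shape)  # torch.Size([2, 10, 768]); x_mask is None (no padding)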