# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------
import math
import torch.nn as nn
import torch
import contextlib
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
from fairseq.models.transformer import Linear
from fairseq.modules import (
    PositionalEmbedding,
    FairseqDropout,
    LayerNorm,
)


class TextDecoderPrenet(nn.Module):
"""
Args:
in_channels (int): the number of input channels
mid_channels (int): the number of intermediate channels
out_channels (int): the number of output channels
kernel_sizes (List[int]): the kernel size for each convolutional layer
"""
def __init__(self, embed_tokens, args):
        super().__init__()
self.dropout_module = FairseqDropout(
args.dropout, module_name=self.__class__.__name__
)
self.decoder_layerdrop = args.decoder_layerdrop
self.num_updates = 0
input_embed_dim = embed_tokens.embedding_dim
embed_dim = args.decoder_embed_dim
self.embed_dim = embed_dim
self.output_embed_dim = args.decoder_output_dim
self.padding_idx = embed_tokens.padding_idx
self.embed_tokens = embed_tokens
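        # Scale token embeddings by sqrt(embed_dim), as in the standard Transformer,
        # unless disabled via --no-scale-embedding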
self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)
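        # Quant-Noise (Fan et al., 2020): applies structured noise to weight blocks
        # during training to prepare the model for post-training quantization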
if not args.adaptive_input and args.quant_noise_pq > 0:
self.quant_noise = apply_quant_noise_(
nn.Linear(embed_dim, embed_dim, bias=False),
args.quant_noise_pq,
args.quant_noise_pq_block_size,
)
else:
self.quant_noise = None
self.project_in_dim = (
Linear(input_embed_dim, embed_dim, bias=False)
if embed_dim != input_embed_dim
else None
)
self.embed_positions = (
PositionalEmbedding(
args.max_text_positions,
embed_dim,
self.padding_idx,
learned=args.decoder_learned_pos,
)
if not args.no_token_positional_embeddings
else None
)
export = getattr(args, "export", False)
if getattr(args, "layernorm_embedding", False):
self.layernorm_embedding = LayerNorm(embed_dim, export=export)
else:
self.layernorm_embedding = None
self.freeze_decoder_updates = args.freeze_decoder_updates
def forward(self, prev_output_tokens, incremental_state=None):
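        # Run the prenet under no_grad (frozen) until `freeze_decoder_updates`
        # training steps have elapsed; ExitStack() acts as a no-op context manager.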
ft = self.freeze_decoder_updates <= self.num_updates
with torch.no_grad() if not ft else contextlib.ExitStack():
return self._forward(prev_output_tokens, incremental_state)
def _forward(self, prev_output_tokens, incremental_state=None):
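        # Only materialize a padding mask when padding tokens are actually present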
if prev_output_tokens.eq(self.padding_idx).any():
x_mask = prev_output_tokens.eq(self.padding_idx)
else:
x_mask = None
# embed positions
positions = None
if self.embed_positions is not None:
positions = self.embed_positions(
prev_output_tokens, incremental_state=incremental_state
)
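        # In incremental (step-by-step) decoding, only the newest token needs to be
        # embedded; earlier steps are cached downstream, so slice to the last position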
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
if positions is not None:
positions = positions[:, -1:]
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
if self.quant_noise is not None:
x = self.quant_noise(x)
if self.project_in_dim is not None:
x = self.project_in_dim(x)
if positions is not None:
x += positions
if self.layernorm_embedding is not None:
x = self.layernorm_embedding(x)
x = self.dropout_module(x)
return x, x_mask, incremental_state
def set_num_updates(self, num_updates):
"""Set the number of parameters updates."""
self.num_updates = num_updates
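

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# The argument values below are assumptions chosen to satisfy the constructor;
# in ArTST they come from the fairseq model/task configuration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        dropout=0.1,
        decoder_layerdrop=0.0,
        decoder_embed_dim=768,
        decoder_output_dim=768,
        no_scale_embedding=False,
        adaptive_input=False,
        quant_noise_pq=0,  # > 0 would enable quant-noise on a projection
        quant_noise_pq_block_size=8,
        max_text_positions=450,  # assumed limit; the real value comes from the config
        decoder_learned_pos=False,
        no_token_positional_embeddings=False,
        freeze_decoder_updates=0,  # with 0, the prenet is trainable from step 0
    )
    vocab_size, padding_idx = 100, 1
    embed_tokens = nn.Embedding(vocab_size, args.decoder_embed_dim, padding_idx)
    prenet = TextDecoderPrenet(embed_tokens, args)

    tokens = torch.randint(2, vocab_size, (2, 7))  # (batch, seq_len), no padding
    x, x_mask, _ = prenet(tokens)
    print(x.shape)  # torch.Size([2, 7, 768])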