# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------

import contextlib
import logging
import math
from typing import List, Tuple

import numpy as np
import torch
import torch.nn as nn

from fairseq.data.data_utils import compute_mask_indices, lengths_to_padding_mask
from fairseq.modules import (
    PositionalEmbedding,
    Fp32GroupNorm,
    FairseqDropout,
    SamePad,
    GradMultiply,
    LayerNorm,
    Fp32LayerNorm,
    TransposeLast,
)

logger = logging.getLogger(__name__)


class LinearLayer(nn.Module):
    """Linear prenet block: projection + LayerNorm + Dropout + ReLU.

    Performs no temporal subsampling, so output lengths equal input lengths.
    """

    def __init__(self, idim, odim, dropout=0):
        super(LinearLayer, self).__init__()
        self.linear = nn.Sequential(
            nn.Linear(idim, odim),
            nn.LayerNorm(odim),
            nn.Dropout(dropout),
            nn.ReLU(),
        )

    def get_out_seq_lens_tensor(self, in_seq_lens_tensor):
        # No subsampling: lengths pass through unchanged.
        return in_seq_lens_tensor.clone()

    def forward(self, src_tokens, src_lengths):
        """
        src_tokens: [B, T, C]
        src_lengths: [B]
        """
        x = self.linear(src_tokens)
        x = x.transpose(0, 1).contiguous()  # -> T x B x C
        return x, src_lengths
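
# A minimal usage sketch for LinearLayer (illustrative only; the helper name
# and the 80-dim fbank / 768-dim model shapes below are assumptions, not part
# of the ArTST training pipeline).
def _linear_layer_usage_sketch():
    layer = LinearLayer(idim=80, odim=768, dropout=0.1)
    src_tokens = torch.randn(4, 100, 80)           # [B, T, C]
    src_lengths = torch.tensor([100, 90, 80, 70])  # [B]
    out, out_lengths = layer(src_tokens, src_lengths)
    assert out.shape == (100, 4, 768)              # forward returns [T, B, odim]
    assert torch.equal(out_lengths, src_lengths)   # no subsampling
    return out, out_lengths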
class SpeechEncoderPrenet(nn.Module):
    """Speech encoder prenet.

    Extracts frame-level features from the raw waveform with a stack of 1-D
    convolutions (ConvFeatureExtractionModel), optionally adds convolutional,
    sinusoidal, or learned absolute positional embeddings, and applies
    HuBERT-style masking during pre-training.

    Args:
        args: model configuration namespace; the fields read in ``__init__``
            below document the expected attributes.
    """

    def __init__(self, args):
        super(SpeechEncoderPrenet, self).__init__()
        self.dropout_module = FairseqDropout(
            p=args.dropout, module_name=self.__class__.__name__
        )
        self.embed_scale = math.sqrt(args.encoder_embed_dim)
        if args.no_scale_embedding:
            self.embed_scale = 1.0
        self.padding_idx = 1
        self.freeze_encoder_updates = args.freeze_encoder_updates
        self.num_updates = 0
        assert args.encoder_speech_prenet in ["conv", "linear"], args.encoder_speech_prenet

        # conv_feature_layers is a string such as
        # "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2"
        feature_enc_layers = eval(args.conv_feature_layers)  # noqa
        self.embed = feature_enc_layers[-1][0]
        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=args.extractor_mode,
            conv_bias=args.conv_bias,
        )
        feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
        # Ratio of label frames to feature frames, e.g.
        # 50 Hz labels * 320x downsampling / 16 kHz audio = 1.0.
        self.feat2tar_ratio = (
            args.label_rates * feature_ds_rate / args.sample_rate
        )

        self.post_extract_proj = (
            nn.Linear(self.embed, args.encoder_embed_dim)
            if self.embed != args.encoder_embed_dim
            else None
        )

        self.use_conv_pos = args.use_conv_pos
        self.use_sinc_pos = args.use_sinc_pos
        self.use_abs_pos = getattr(args, "use_abs_pos", False)

        self.feature_grad_mult = args.feature_grad_mult

        # Constructed unconditionally: _forward always applies layer_norm,
        # including when use_conv_pos is False.
        self.layer_norm = LayerNorm(self.embed)
        if self.use_conv_pos:
            self.pos_conv = nn.Conv1d(
                args.encoder_embed_dim,
                args.encoder_embed_dim,
                kernel_size=args.conv_pos,
                padding=args.conv_pos // 2,
                groups=args.conv_pos_groups,
            )
            dropout = 0
            std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * args.encoder_embed_dim))
            nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
            nn.init.constant_(self.pos_conv.bias, 0)
            self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
            self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())

        assert not (self.use_sinc_pos and self.use_abs_pos), \
            f"sinc pos: {self.use_sinc_pos} abs pos: {self.use_abs_pos}"
        if self.use_sinc_pos:
            self.embed_positions = PositionalEmbedding(
                args.max_speech_positions, args.encoder_embed_dim, self.padding_idx
            )
        if self.use_abs_pos:
            self.embed_positions = PositionalEmbedding(
                args.max_speech_positions, args.encoder_embed_dim, self.padding_idx, learned=True
            )

        # HuBERT-style masking hyper-parameters
        self.mask_prob = args.mask_prob
        self.mask_selection = args.mask_selection
        self.mask_other = args.mask_other
        self.hubert_mask_length = args.hubert_mask_length
        self.no_mask_overlap = args.no_mask_overlap
        self.mask_min_space = args.mask_min_space
        self.mask_channel_prob = args.mask_channel_prob
        self.mask_channel_selection = args.mask_channel_selection
        self.mask_channel_other = args.mask_channel_other
        self.mask_channel_length = args.mask_channel_length
        self.no_mask_channel_overlap = args.no_mask_channel_overlap
        self.mask_channel_min_space = args.mask_channel_min_space
        self.mask_emb = nn.Parameter(
            torch.FloatTensor(args.encoder_embed_dim).uniform_()
        )

    def forward(self, src_tokens, require_feat_pen=False, target_list=None, padding_mask=None, mask=True):
        # Run without gradients until freeze_encoder_updates steps have elapsed.
        ft = self.freeze_encoder_updates <= self.num_updates
        with torch.no_grad() if not ft else contextlib.ExitStack():
            return self._forward(src_tokens, require_feat_pen, target_list, padding_mask, mask)

    def _forward(self, src_tokens, require_feat_pen=False, target_list=None, padding_mask=None, mask=True):
        if self.feature_grad_mult > 0:
            x = self.feature_extractor(src_tokens)
            x = x.transpose(1, 2).transpose(0, 1)  # [length, batch, hidden_size]
            if self.feature_grad_mult != 1.0:
                x = GradMultiply.apply(x, self.feature_grad_mult)
        else:
            with torch.no_grad():
                x = self.feature_extractor(src_tokens)
                x = x.transpose(1, 2).transpose(0, 1)  # [length, batch, hidden_size]
        x = x.transpose(0, 1)  # [batch, length, hidden_size]
        encoder_padding_mask = padding_mask

        x = x.transpose(1, 2)  # [batch, hidden_size, length]
        if target_list is not None:
            x, target_list = self.forward_targets(x, target_list)
        features_pen = x.float().pow(2).mean()
        x = x.transpose(1, 2)  # [batch, length, hidden_size]
        x = self.layer_norm(x)
        encoder_padding_mask = self.forward_padding_mask(x, encoder_padding_mask)
        if self.post_extract_proj is not None:
            x = self.post_extract_proj(x)
        x = self.dropout_module(x)

        if mask:
            x, mask_indices = self.apply_hubert_mask(x, encoder_padding_mask)
        else:
            mask_indices = None

        if self.use_conv_pos:
            positions = self.pos_conv(x.transpose(1, 2))
            positions = positions.transpose(1, 2)
            x = x + positions

        if self.use_sinc_pos:
            positions = self.embed_positions(encoder_padding_mask)
            x = x + positions

        if require_feat_pen:
            return (x, features_pen, mask_indices, target_list), encoder_padding_mask
        else:
            # For consistency with the encoder interface
            return x, encoder_padding_mask

    def forward_targets(
        self,
        features: torch.Tensor,
        target_list: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        # Trim features to ensure labels exist, then select the aligned labels.
        feat_tsz = features.size(2)
        targ_tsz = min([t.size(1) for t in target_list])
        if self.feat2tar_ratio * feat_tsz > targ_tsz:
            feat_tsz = int(targ_tsz / self.feat2tar_ratio)
            features = features[..., :feat_tsz]
        target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
        target_list = [t[:, target_inds.long()] for t in target_list]
        return features, target_list

    def forward_padding_mask(
        self,
        features: torch.Tensor,
        padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        # Downsample the sample-level padding mask to the feature frame rate:
        # a frame counts as padding only if all samples it covers are padding.
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(
            padding_mask.size(0), features.size(1), -1
        )
        padding_mask = padding_mask.all(-1)
        return padding_mask

    def get_src_lengths(self, src_lengths):
        return self.feature_extractor.get_out_seq_lens_tensor(src_lengths)

    def apply_hubert_mask(self, x, padding_mask):
        B, T, C = x.shape
        if self.mask_prob > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob,
                self.hubert_mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x[mask_indices] = self.mask_emb
        else:
            mask_indices = None

        if self.mask_channel_prob > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        return x, mask_indices

    def set_num_updates(self, num_updates):
        """Set the number of parameter updates."""
        self.num_updates = num_updates
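
# A minimal sketch of the HuBERT-style span masking performed by
# SpeechEncoderPrenet.apply_hubert_mask above (the helper name, shapes, and
# the mask_prob/mask_length values are assumptions for illustration; the real
# module uses its learned self.mask_emb and config-driven hyper-parameters).
def _hubert_mask_sketch():
    B, T, C = 2, 50, 768
    x = torch.randn(B, T, C)
    mask_emb = torch.zeros(C)  # stand-in for the learned mask embedding
    # Sample boolean span masks over (B, T), then overwrite masked frames.
    mask_indices = compute_mask_indices(
        (B, T), None, mask_prob=0.65, mask_length=10, min_masks=2
    )
    mask_indices = torch.from_numpy(mask_indices).to(x.device)
    x[mask_indices] = mask_emb
    return x, mask_indices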
class ConvFeatureExtractionModel(nn.Module):
    """wav2vec2-style stack of 1-D convolutions over the raw waveform.

    Args:
        conv_layers: list of (dim, kernel_size, stride) triples.
        dropout: dropout applied after each convolution.
        mode: "default" (group norm on the first layer) or "layer_norm".
        conv_bias: whether the convolutions use a bias term.
    """

    def __init__(
        self,
        conv_layers: List[Tuple[int, int, int]],
        dropout: float = 0.0,
        mode: str = "default",
        conv_bias: bool = False,
    ):
        super().__init__()

        assert mode in {"default", "layer_norm"}

        def block(
            n_in,
            n_out,
            k,
            stride,
            is_layer_norm=False,
            is_group_norm=False,
            conv_bias=False,
        ):
            def make_conv():
                conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
                nn.init.kaiming_normal_(conv.weight)
                return conv

            assert not (
                is_layer_norm and is_group_norm
            ), "layer norm and group norm are exclusive"

            if is_layer_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.Sequential(
                        TransposeLast(),
                        Fp32LayerNorm(n_out, elementwise_affine=True),
                        TransposeLast(),
                    ),
                    nn.GELU(),
                )
            elif is_group_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    Fp32GroupNorm(n_out, n_out, affine=True),
                    nn.GELU(),
                )
            else:
                return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())

        in_d = 1
        self.conv_layers = nn.ModuleList()
        self.conv_layers_infos = conv_layers
        for i, cl in enumerate(conv_layers):
            assert len(cl) == 3, "invalid conv definition: " + str(cl)
            (dim, k, stride) = cl

            self.conv_layers.append(
                block(
                    in_d,
                    dim,
                    k,
                    stride,
                    is_layer_norm=mode == "layer_norm",
                    is_group_norm=mode == "default" and i == 0,
                    conv_bias=conv_bias,
                )
            )
            in_d = dim

    def forward(self, x):
        # BxT -> BxCxT
        x = x.unsqueeze(1)
        for conv in self.conv_layers:
            x = conv(x)
        return x

    def get_out_seq_lens_nonmask_after_a_layer(self, in_seq_lens_tensor, i):
        """Return the non-padding (0/1) mask and output lengths after layer i.

        Args:
            in_seq_lens_tensor (LongTensor): input lengths [B]

        Returns:
            (FloatTensor, LongTensor): 0/1 non-padding mask and output lengths
        """
        out_lengths = in_seq_lens_tensor.clone()
        out_lengths = (
            (out_lengths.float() - (self.conv_layers_infos[i][1] - 1) - 1)
            / self.conv_layers_infos[i][-1]
            + 1
        ).floor().long()
        out_nonmask = (~lengths_to_padding_mask(out_lengths)).float()
        return out_nonmask, out_lengths

    def get_out_seq_lens_tensor(self, in_seq_lens_tensor):
        # Apply the standard conv length recurrence once per layer:
        # L_out = floor((L_in - (k - 1) - 1) / stride + 1)
        out = in_seq_lens_tensor.clone()
        for i in range(len(self.conv_layers)):
            out = (
                (out.float() - (self.conv_layers_infos[i][1] - 1) - 1)
                / self.conv_layers_infos[i][-1]
                + 1
            ).floor().long()
        return out
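
# Sanity check for the output-length recurrence above: each conv layer maps
# L -> floor((L - (k - 1) - 1) / stride + 1). The conv stack below is the
# common wav2vec2-style default (an assumption; actual ArTST configs come
# from args.conv_feature_layers), giving roughly 320x downsampling.
if __name__ == "__main__":
    conv_layers = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
    extractor = ConvFeatureExtractionModel(conv_layers=conv_layers)
    lengths = torch.tensor([16000, 32000])  # 1 s and 2 s of 16 kHz audio
    print(extractor.get_out_seq_lens_tensor(lengths))  # tensor([49, 99])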