"""Building blocks for speech SSL models supporting pruning. Originally from: https://github.com/pytorch/audio/blob/main/torchaudio/models/wav2vec2/components.py """ import math from collections import defaultdict from typing import List, Optional, Tuple import torch from torch import Tensor, nn from torch.nn import Module from .hardconcrete import HardConcrete from .pruning_utils import ( prune_conv1d_layer, prune_layer_norm, prune_linear_layer, ) def _init_transformer_params(module): """ Initialize the weights of Transformer module in Wav2Vec2/HuBERT. If the module is ``nn.Linear``, normalize the weight with mean 0 and standard deviation 0.02. If ``bias`` is set to ``True`` in the module, set ``bias`` to 0. If the module is ``nn.Embedding``, normalize the weight with mean 0 and standard deviation 0.02. If ``padding_idx`` is not None, set the weight of padding to 0. Note: Ths method corresponds to `init_bert_params `__ in the original ``fairseq`` implementation. """ def normal_(data): data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device)) if isinstance(module, nn.Linear): normal_(module.weight.data) if module.bias is not None: module.bias.data.zero_() if isinstance(module, nn.Embedding): normal_(module.weight.data) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() class LayerNorm(nn.LayerNorm): """Layer norm with transpose""" def forward(self, input: Tensor) -> Tensor: x = input.transpose(-2, -1) x = nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) x = x.transpose(-2, -1) return x class ConvLayerBlock(Module): """Convolution unit of FeatureExtractor""" def __init__( self, in_channels: int, out_channels: int, kernel_size: int, stride: int, bias: bool, layer_norm: Optional[Module], prune_conv_channels: bool = False, ): super().__init__() self.kernel_size = kernel_size self.stride = stride self.layer_norm = layer_norm self.conv = nn.Conv1d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, bias=bias, ) if prune_conv_channels: self.hard_concrete = HardConcrete(n_in=out_channels, init_mean=0.01) else: self.hard_concrete = None def forward( self, x: Tensor, length: Optional[Tensor], ) -> Tuple[Tensor, Optional[Tensor]]: """ Args: x (Tensor): Shape: ``[batch, in_channels, in_frame]``. length (Tensor or None, optional): Shape ``[batch, ]``. Returns: Tensor: Shape ``[batch, out_channels, out_frames]``. Optional[Tensor]: Shape ``[batch, ]``. """ x = self.conv(x) if self.layer_norm is not None: x = self.layer_norm(x) x = nn.functional.gelu(x) if self.hard_concrete is not None: channel_mask = self.hard_concrete() # hard concrete mask, (out_channels,) x = x * channel_mask.unsqueeze(-1) if length is not None: length = torch.div(length - self.kernel_size, self.stride, rounding_mode="floor") + 1 # When input length is 0, the resulting length can be negative. So fix it here. length = torch.max(torch.zeros_like(length), length) return x, length def get_num_params_and_out_channels(self, in_channels): if self.hard_concrete is not None: out_channels = self.hard_concrete.l0_norm() else: out_channels = self.conv.out_channels num_params = in_channels * out_channels * self.kernel_size if self.conv.bias is not None: num_params += out_channels if self.layer_norm is not None: num_params += out_channels * 2 return num_params, out_channels class FeatureExtractor(Module): """Extract features from audio Args: conv_layers (nn.ModuleList): convolution layers """ def __init__( self, conv_layers: nn.ModuleList, ): super().__init__() self.conv_layers = conv_layers # NOTE: a dummy weight used to save the soft mask of the last conv layer self.dummy_weight = nn.Parameter( torch.ones(conv_layers[-1].conv.out_channels, dtype=torch.float32), requires_grad=False ) def forward( self, x: Tensor, length: Optional[Tensor], ) -> Tuple[Tensor, Optional[Tensor]]: """ Args: x (Tensor): Input Tensor representing a batch of audio, shape: ``[batch, time]``. length (Tensor or None, optional): Valid length of each input sample. shape: ``[batch, ]``. Returns: Tensor: The resulting feature, shape: ``[batch, frame, feature]`` Optional[Tensor]: Valid length of each output sample. shape: ``[batch, ]``. """ if x.ndim != 2: raise ValueError("Expected the input Tensor to be 2D (batch, time), " "but received {list(x.shape)}") x = x.unsqueeze(1) # (batch, channel==1, frame) for layer in self.conv_layers: x, length = layer(x, length) # (batch, feature, frame) x = x.transpose(1, 2) # (batch, frame, feature) x = x * self.dummy_weight return x, length def get_num_params_and_final_out_channels(self): in_channels = 1 num_params = 0 for layer in self.conv_layers: layer_params, in_channels = layer.get_num_params_and_out_channels(in_channels) num_params += layer_params num_params += in_channels # dummy weight return num_params, in_channels def prune(self): """"Prune conv layers and dummy weight based on hardconcrete parameters. This is an in-place operation. """ new_config = [] # [(output_channel, kernel_size, stride), ...] for idx, layer in enumerate(self.conv_layers): if layer.hard_concrete is not None: assert not layer.hard_concrete.training mask = layer.hard_concrete() # (out_features,) index = mask.nonzero().squeeze(-1) # 2D -> 1D assert len(index) > 0, f"Conv channels pruned to zero at index {idx}" new_config.append( (len(index), layer.kernel_size, layer.stride) ) # prune the current layer prune_conv1d_layer(layer.conv, index, "output") if layer.layer_norm is not None: prune_layer_norm(layer.layer_norm, index) # prune the next layer if idx == len(self.conv_layers) - 1: self.dummy_weight.data *= mask self.dummy_weight = nn.Parameter( self.dummy_weight.index_select(0, index).clone().detach(), requires_grad=False ) else: self.conv_layers[idx+1].conv.weight.data *= mask.unsqueeze(-1) prune_conv1d_layer(self.conv_layers[idx+1].conv, index, dim="input") layer.hard_concrete = None else: new_config.append( (layer.conv.out_channels, layer.kernel_size, layer.stride) ) index = torch.arange(layer.conv.out_channels, dtype=torch.long) return new_config, index class FeatureProjection(Module): """Layer that connects FeatureExtractor and Encoder Projects features to encoder dimension. Args: in_features (int): Input feature dim. out_features (int): Output feature dim. dropout (float): Dropout probability. """ def __init__( self, in_features: int, out_features: int, dropout: float, ): super().__init__() self.layer_norm = nn.LayerNorm(in_features) self.projection = nn.Linear( in_features, out_features, ) self.dropout = nn.Dropout(dropout) def forward(self, x): """ Args: x (Tensor): Feature Tensor. shape: ``[batch, frame, in_feature]`` Returns: Tensor: Projected features. ``[batch, frame, out_feature]``. """ x = self.layer_norm(x) x = self.projection(x) x = self.dropout(x) return x def get_num_params(self, in_features): return in_features * 2 + (in_features + 1) * self.projection.out_features class ConvolutionalPositionalEmbedding(Module): """Positional embedding which is placed at the beginning of Transformer. Args: embed_dim (int): Feature dimension of the input Tensor. kernel_size (int): The number of frames to be use. groups (int): The number of groups in feature dimensions. """ def __init__( self, embed_dim: int, kernel_size: int, groups: int, ): super().__init__() self.embed_dim = embed_dim self.kernel_size = kernel_size self.conv = nn.Conv1d( in_channels=embed_dim, out_channels=embed_dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=groups, ) self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) self.num_remove: int = 1 if kernel_size % 2 == 0 else 0 def __prepare_scriptable__(self): for hook in self.conv._forward_pre_hooks.values(): # The hook we want to remove is an instance of WeightNorm class, so # normally we would do `if isinstance(...)` but this class is not accessible # because of shadowing, so we check the module name directly. # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3 if hook.__module__ == "torch.nn.utils.weight_norm" and hook.__class__.__name__ == "WeightNorm": torch.nn.utils.remove_weight_norm(self.conv) return self def forward(self, x): """ Args: x (Tensor): shape ``[batch, frame, feature]``. Returns: Tensor: The resulting feature. Shape ``[batch, frame, feature]``. """ x = x.transpose(-2, -1) x = self.conv(x) if self.num_remove > 0: x = x[..., : -self.num_remove] x = torch.nn.functional.gelu(x) x = x.transpose(-2, -1) return x class SelfAttention(Module): """Multihead Self Attention module Args: embed_dim (int): Total dimension of the model. num_heads (int): The number of heads. dropout (float, optional): Dropout probability on attn_output_weights. Default: ``0.0`` """ def __init__( self, embed_dim: int, num_heads: int, head_dim: int, dropout: float = 0.0, prune_heads: bool = False, # whether to prune attention heads prune_layer: bool = False, # whether to prune entire attention layers ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.head_dim = head_dim self.dropout = torch.nn.Dropout(dropout) self.scaling = self.head_dim**-0.5 self.k_proj = nn.Linear(embed_dim, num_heads * head_dim, bias=True) self.v_proj = nn.Linear(embed_dim, num_heads * head_dim, bias=True) self.q_proj = nn.Linear(embed_dim, num_heads * head_dim, bias=True) self.out_proj = nn.Linear(num_heads * head_dim, embed_dim, bias=True) if prune_heads: self.hard_concrete_for_heads = HardConcrete(n_in=num_heads, init_mean=0.01) else: self.hard_concrete_for_heads = None if prune_layer: self.hard_concrete_for_layer = HardConcrete(n_in=1, init_mean=0.01) else: self.hard_concrete_for_layer = None def forward( self, x: Tensor, attention_mask: Optional[Tensor] = None, position_bias: Optional[Tensor] = None, key_padding_mask: Optional[Tensor] = None, ) -> Tuple[Tensor, Optional[Tensor]]: """ Args: x (Tensor): shape: ``[batch_size, sequence_length, embed_dim]``. attention_mask (Tensor or ``None``, optional): shape: ``[batch_size, 1, sequence_length, sequence_length]`` position_bias: Not used. Only for the compatibility with :py:class:`WavLMSelfAttention`. key_padding_mask (Tensor or ``None``): Not used. Only for the compatibility with :py:class:`WavLMSelfAttention`. Returns: (Tensor, ``None``): The resulting attention output and ``None`` (necessary for compatibility with :py:class:`WavLMSelAttention`). Attention output shape: ``[batch, sequence_length, embed_dim]``. """ if x.ndim != 3 or x.shape[2] != self.embed_dim: raise ValueError( f"The expected input shape is (batch, sequence, embed_dim=={self.embed_dim}). " f"Found {x.shape}." ) batch_size, length, embed_dim = x.size() shape = (batch_size, length, self.num_heads, self.head_dim) q = self.q_proj(x).view(*shape).transpose(2, 1) # B, nH, L, Hd k = self.k_proj(x).view(*shape).permute(0, 2, 3, 1) # B, nH, Hd, L v = self.v_proj(x).view(*shape).transpose(2, 1) # B, nH, L, Hd # scale down q to avoid value overflow. weights = (self.scaling * q) @ k # B, nH, L, L if attention_mask is not None: weights += attention_mask # subtracting a constant value from the tensor won't change the output of softmax. # apply the subtraction to avoid value overflow in torch.nn.functional.softmax. # for more details, please see Equation 7 in https://arxiv.org/abs/2112.08778 weights = weights - weights.max(dim=-1, keepdim=True)[0] weights = torch.nn.functional.softmax(weights, dim=-1) weights = self.dropout(weights) output = weights @ v # B, nH, L, Hd if self.hard_concrete_for_heads is not None: head_mask = self.hard_concrete_for_heads() # (nH,) output = output * head_mask.unsqueeze(-1).unsqueeze(-1) output = output.transpose(2, 1).reshape(batch_size, length, self.num_heads * self.head_dim) output = self.out_proj(output) if self.hard_concrete_for_layer is not None: layer_mask = self.hard_concrete_for_layer() # (1,) output = output * layer_mask return output, None # Necessary for compatibility with WavLMSelAttention def get_num_params(self): if self.hard_concrete_for_heads is not None: num_heads = self.hard_concrete_for_heads.l0_norm() else: num_heads = self.num_heads num_params = (self.embed_dim + 1) * num_heads * self.head_dim * 3 \ + (num_heads * self.head_dim + 1) * self.embed_dim if self.hard_concrete_for_layer is not None: num_params *= self.hard_concrete_for_layer.l0_norm() return num_params def prune(self): new_config = { "use_attention": True, "num_heads": self.num_heads, } if self.hard_concrete_for_layer is not None: assert not self.hard_concrete_for_layer.training layer_mask = self.hard_concrete_for_layer() # (1,) self.out_proj.weight.data *= layer_mask self.out_proj.bias.data *= layer_mask if layer_mask == 0: new_config["use_attention"] = False self.hard_concrete_for_layer = None if self.hard_concrete_for_heads is not None: assert not self.hard_concrete_for_heads.training head_mask = self.hard_concrete_for_heads() # (num_heads,) new_config["num_heads"] = len(head_mask.nonzero()) if new_config["num_heads"] == 0: new_config["use_attention"] = False else: full_mask = head_mask.repeat_interleave(self.head_dim) full_index = full_mask.nonzero().squeeze(-1) # 1D prune_linear_layer(self.k_proj, full_index, "output") prune_linear_layer(self.v_proj, full_index, "output") prune_linear_layer(self.q_proj, full_index, "output") self.out_proj.weight.data *= full_mask prune_linear_layer(self.out_proj, full_index, "input") self.hard_concrete_for_heads = None return new_config class WavLMSelfAttention(SelfAttention): """Multi-headed self-attention for WavLM model :cite:`chen2022wavlm`. Args: embed_dim (int): Total dimension of the model. num_heads (int): The number of heads. dropout (float, optional): Dropout probability on attn_output_weights. (Default: to ``0.0``) bias (bool, optional): If ``True``, add bias to input / output projection layers. (Default: ``True``) has_relative_attention_bias (bool, optional): If ``True``, apply relative position embedding. Necessary in the first encoder layer, but not in the subsequent ones. (Default: ``False``) num_buckets (int, optional): Number of buckets for relative position embedding. (Default: ``32``) max_distance (int, optional): Naximum distance for relative position embedding. (Default: ``128``) gru_rel_pos (bool, optional): If ``True``, apply gated relative position embedding. (Default: ``False``) """ def __init__( self, embed_dim: int, total_num_heads: int, remaining_heads: Optional[List[int]] = None, dropout: float = 0.0, bias: bool = True, has_relative_attention_bias: bool = False, num_buckets: int = 32, max_distance: int = 128, gru_rel_pos: bool = True, prune_heads: bool = False, prune_layer: bool = False, ): self.total_num_heads = total_num_heads if remaining_heads is None: self.remaining_heads = list(range(total_num_heads)) else: self.remaining_heads = remaining_heads # list of indices self.head_dim = embed_dim // total_num_heads super().__init__(embed_dim, len(self.remaining_heads), self.head_dim, dropout, prune_heads, prune_layer) self.has_relative_attention_bias = has_relative_attention_bias self.num_buckets = num_buckets self.max_distance = max_distance if has_relative_attention_bias: self.rel_attn_embed = nn.Embedding(num_buckets, total_num_heads) else: self.rel_attn_embed = None # override linear layers to customize bias self.k_proj = nn.Linear(embed_dim, len(self.remaining_heads) * self.head_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, len(self.remaining_heads) * self.head_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, len(self.remaining_heads) * self.head_dim, bias=bias) self.out_proj = nn.Linear(len(self.remaining_heads) * self.head_dim, embed_dim, bias=bias) self.gru_rel_pos = gru_rel_pos if self.gru_rel_pos: self.gru_rel_pos_linear = nn.Linear(self.head_dim, 8) self.gru_rel_pos_const = nn.Parameter(torch.ones(1, total_num_heads, 1, 1)) self.has_position_bias = True def compute_bias(self, query_length: int, key_length: int) -> Tensor: """Compute relative position embeddings for WavLM model. Args: query_length (int): Query position can take values between 0 and ``query_length - 1``. key_length (int): Key position can take values between 0 and ``key_length - 1``. Returns: Tensor of shape `(num_heads, query_length, key_length)`, relative positions embeddings """ context_position = torch.arange(query_length, dtype=torch.long)[:, None] memory_position = torch.arange(key_length, dtype=torch.long)[None, :] relative_position = memory_position - context_position # Shape (query_length, key_length) relative_position_bucket = self._relative_positions_bucket(relative_position, bidirectional=True) relative_position_bucket = relative_position_bucket.to(self.rel_attn_embed.weight.device) values = self.rel_attn_embed(relative_position_bucket) # Shape (query_length, key_length, num_heads) values = values.permute([2, 0, 1]) return values def _relative_positions_bucket(self, relative_positions: Tensor, bidirectional: bool = True): """Compute relative position buckets for WavLM model. Computation similar to formula (5) in WavLM paper :cite:`chen2022wavlm`. Args: relative_positions (Tensor): Relative offsets between query and key positions, of shape ``(query_length, key_length)``. bidirectional (bool): If ``True``, values will be filled both above and below the diagonal in the resulting matrix. If ``False``, the elements above the diagonal (i.e. with negative relative offsets) will be set to zero. (Default ``True``) Returns: Tensor of shape ``(query_length, key_length)`` filled bucketed values of with relative positions. """ num_buckets = self.num_buckets max_distance = self.max_distance # Shape (query_length, key_length) relative_buckets = torch.zeros_like(relative_positions, dtype=torch.long) if bidirectional: num_buckets = num_buckets // 2 relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets relative_positions = torch.abs(relative_positions) else: relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions)) max_exact = num_buckets // 2 is_small = relative_positions < max_exact relative_postion_if_large = max_exact + ( torch.log(relative_positions.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) ).to(torch.long) relative_postion_if_large = torch.min( relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) ) relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large) return relative_buckets def forward( self, query: Tensor, attention_mask: Optional[Tensor] = None, position_bias: Optional[Tensor] = None, key_padding_mask: Optional[Tensor] = None, ) -> Tuple[Tensor, Optional[Tensor]]: """ Args: query (Tensor): Input of shape ``(batch_size, src_len, embed_dim)``. key_padding_mask (Tensor or None, optional): Mask to exclude keys that are pads, of shape `(batch, src_len)`, where padding elements are indicated by 1s. (Default: ``None``) attn_mask: Needs to be ``None``. The argument exists for compatibility with ``EncoderLayer``. (Default: ``None``) position_bias (Tensor or None, optional): Position bias of shape ``(batch_size * num_heads, src_len, src_len)``. When used inside WavLM model encoder, will be generated in the first layer and then passed from each encoder layer to the next one. (Default: ``None``) Returns: attn_output (Tensor): Attention output of shape ``(batch_size, src_len, embed_dim)``. position_bias (Tensor or None): Position bias of shape ``(batch_size * num_heads, src_len, src_len)``. """ bsz, seq_len, embed_dim = query.size() assert embed_dim == self.embed_dim assert key_padding_mask is None # only for the first layer if self.rel_attn_embed is not None and position_bias is None: position_bias = self.compute_bias(seq_len, seq_len) position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.total_num_heads, seq_len, seq_len) attn_mask_rel_pos: Optional[Tensor] = None if position_bias is not None: attn_mask_rel_pos = position_bias if self.gru_rel_pos: # Apply gating on relative position bias query_layer = query.view(bsz, seq_len, self.total_num_heads, -1) query_layer = query_layer.permute(0, 2, 1, 3) gate_a, gate_b = torch.sigmoid( self.gru_rel_pos_linear(query_layer).view(bsz, self.total_num_heads, seq_len, 2, 4).sum(-1, keepdim=False) ).chunk(2, dim=-1) gate_a_1 = gate_a * (gate_b * self.gru_rel_pos_const - 1.0) + 2.0 attn_mask_rel_pos = gate_a_1.view(bsz * self.total_num_heads, -1, 1) * position_bias attn_mask_rel_pos = attn_mask_rel_pos.view((-1, seq_len, seq_len)) attn_mask_rel_pos = attn_mask_rel_pos.reshape(bsz, self.total_num_heads, seq_len, seq_len)[:, self.remaining_heads, :, :] attn_mask = attn_mask_rel_pos if attention_mask is not None: attn_mask = attn_mask + attention_mask if key_padding_mask is not None: attn_mask = attn_mask.masked_fill( key_padding_mask.reshape(bsz, 1, 1, seq_len), float("-inf") ) attn_output, _ = super().forward(query, attention_mask=attn_mask) return attn_output, position_bias def prune(self): new_config = { "use_attention": True, "remaining_heads": self.remaining_heads, } if self.hard_concrete_for_layer is not None: assert not self.hard_concrete_for_layer.training layer_mask = self.hard_concrete_for_layer() # (1,) self.out_proj.weight.data *= layer_mask self.out_proj.bias.data *= layer_mask if layer_mask == 0: new_config["use_attention"] = False self.hard_concrete_for_layer = None if self.hard_concrete_for_heads is not None: assert not self.hard_concrete_for_heads.training head_mask = self.hard_concrete_for_heads() # (num_heads,) new_config["remaining_heads"] = head_mask.nonzero().squeeze(-1).tolist() if len(new_config["remaining_heads"]) == 0: new_config["use_attention"] = False else: full_mask = head_mask.repeat_interleave(self.head_dim) full_index = full_mask.nonzero().squeeze(-1) # 1D prune_linear_layer(self.k_proj, full_index, "output") prune_linear_layer(self.v_proj, full_index, "output") prune_linear_layer(self.q_proj, full_index, "output") self.out_proj.weight.data *= full_mask prune_linear_layer(self.out_proj, full_index, "input") self.hard_concrete_for_heads = None return new_config class FeedForward(Module): """Layer that follows attention layer in encoder layer.""" def __init__( self, io_features: int, intermediate_features: int, intermediate_dropout: float, output_dropout: float, prune_intermediate: bool = False, prune_layer: bool = False, ): super().__init__() self.intermediate_dense = nn.Linear(io_features, intermediate_features) self.intermediate_dropout = nn.Dropout(intermediate_dropout) self.output_dense = nn.Linear(intermediate_features, io_features) self.output_dropout = nn.Dropout(output_dropout) if prune_intermediate: self.hard_concrete_for_intermediate = HardConcrete( n_in=intermediate_features, init_mean=0.5 ) else: self.hard_concrete_for_intermediate = None if prune_layer: self.hard_concrete_for_layer = HardConcrete(n_in=1, init_mean=0.01) else: self.hard_concrete_for_layer = None def forward(self, x): """ Args: x (Tensor): shape: `(batch, sequence_length, io_features)` Returns: x (Tensor): shape: `(batch, sequence_length, io_features)` """ x = self.intermediate_dense(x) x = torch.nn.functional.gelu(x) x = self.intermediate_dropout(x) if self.hard_concrete_for_intermediate is not None: intermediate_mask = self.hard_concrete_for_intermediate() # (intermediate_features,) x = x * intermediate_mask x = self.output_dense(x) x = self.output_dropout(x) if self.hard_concrete_for_layer is not None: layer_mask = self.hard_concrete_for_layer() # (1,) x = x * layer_mask return x def get_num_params(self): io_features = self.intermediate_dense.in_features if self.hard_concrete_for_intermediate is not None: intermediate_features = self.hard_concrete_for_intermediate.l0_norm() else: intermediate_features = self.intermediate_dense.out_features num_params = (io_features + 1) * intermediate_features + (intermediate_features + 1) * io_features if self.hard_concrete_for_layer is not None: num_params *= self.hard_concrete_for_layer.l0_norm() return num_params def prune(self): new_config = { "use_feed_forward": True, "ff_interm_features": self.intermediate_dense.out_features } if self.hard_concrete_for_layer is not None: assert not self.hard_concrete_for_layer.training layer_mask = self.hard_concrete_for_layer() self.output_dense.weight.data *= layer_mask self.output_dense.bias.data *= layer_mask if layer_mask == 0: new_config["use_feed_forward"] = False self.hard_concrete_for_layer = None if self.hard_concrete_for_intermediate is not None: assert not self.hard_concrete_for_intermediate.training interm_mask = self.hard_concrete_for_intermediate() interm_index = interm_mask.nonzero().squeeze(-1) # NOTE: must specify dim=-1 new_config["ff_interm_features"] = len(interm_index) if new_config["ff_interm_features"] == 0: new_config["use_feed_forward"] = False else: prune_linear_layer(self.intermediate_dense, interm_index, "output") self.output_dense.weight.data *= interm_mask prune_linear_layer(self.output_dense, interm_index, "input") self.hard_concrete_for_intermediate = None return new_config class EncoderLayer(Module): """A layer unit in encoder. Combines multihead self attention and feed forward.""" def __init__( self, attention: Optional[Module], # can be None if the entire layer is pruned dropout: float, layer_norm_first: bool, feed_forward: Optional[Module], # can be None if the entire layer is pruned embed_dim: int, ): super().__init__() self.attention = attention self.dropout = nn.Dropout(dropout) self.layer_norm = nn.LayerNorm(embed_dim) self.layer_norm_first = layer_norm_first self.feed_forward = feed_forward self.final_layer_norm = nn.LayerNorm(embed_dim) self.embed_dim = embed_dim def forward( self, x: Tensor, attention_mask: Optional[Tensor] = None, position_bias: Optional[Tensor] = None, key_padding_mask: Optional[Tensor] = None, ) -> Tuple[Tensor, Optional[Tensor]]: """ Args: x (Tensor): Input of shape ``(batch, sequence_length, embed_dim)``. attention_mask (Tensor or ``None``, optional): attention mask of shape ``(batch, 1, sequence_length, sequence_length)``. (Default: ``None``) position_bias (Tensor or ``None``, optional): position bias of shape ``(batch_size * num_heads, src_len, src_len)``. Only necessary for WavLM model, ``None`` otherwise. (Default: ``None``) key_padding_mask (Tensor or ``None``, optional): key padding mask of shape ``(batch_size, src_len)``. Only used for WavLM model, ignored otherwise. (Default: ``None``) Returns: (x, position_bias): Shapes are the same as in the input. Position bias is only relevant for WaLM model, ``None`` otherwise. """ if self.attention is not None: residual = x if self.layer_norm_first: x = self.layer_norm(x) x, position_bias = self.attention( x, attention_mask=attention_mask, position_bias=position_bias, key_padding_mask=key_padding_mask ) x = self.dropout(x) x = residual + x if self.layer_norm_first: if self.feed_forward is not None: x = x + self.feed_forward(self.final_layer_norm(x)) else: # NOTE: for post norm, the layer norms should always be applied even if the layers are pruned. x = self.layer_norm(x) if self.feed_forward is not None: x = x + self.feed_forward(x) x = self.final_layer_norm(x) return x, position_bias def get_num_params(self): num_params = self.embed_dim * 2 * 2 # two layer norms if self.attention is not None: num_params += self.attention.get_num_params() if self.feed_forward is not None: num_params += self.feed_forward.get_num_params() return num_params class Transformer(Module): def __init__( self, pos_conv_embed: Module, dropout: float, layers: Module, layer_norm_first: bool, layer_drop: float, ): super().__init__() self.pos_conv_embed = pos_conv_embed self.layer_norm = nn.LayerNorm(pos_conv_embed.embed_dim) self.layer_norm_first = layer_norm_first self.layer_drop = layer_drop self.dropout = nn.Dropout(dropout) self.layers = layers def _preprocess(self, x: Tensor): x = x + self.pos_conv_embed(x) if self.layer_norm_first: x = self.layer_norm(x) x = self.dropout(x) return x def forward( self, x: Tensor, attention_mask: Optional[Tensor] = None, position_bias: Optional[Tensor] = None, ) -> Tensor: x = self._preprocess(x) for layer in self.layers: if not (self.training and torch.rand(1).item() <= self.layer_drop): x, position_bias = layer(x, attention_mask, position_bias=position_bias) if not self.layer_norm_first: x = self.layer_norm(x) return x def get_intermediate_outputs( self, x: Tensor, attention_mask: Optional[Tensor] = None, num_layers: Optional[int] = None, position_bias: Optional[Tensor] = None, ) -> List[Tensor]: if num_layers is not None: if not 0 < num_layers <= len(self.layers): raise ValueError(f"`num_layers` must be between [1, {len(self.layers)}]") ret: List[Tensor] = [] x = self._preprocess(x) for layer in self.layers: x, position_bias = layer(x, attention_mask, position_bias=position_bias) ret.append(x) if num_layers is not None and len(ret) >= num_layers: return ret return ret def get_num_params(self): # pos_conv_embed and layer_norm num_params = sum(p.numel() for p in self.pos_conv_embed.parameters()) + self.pos_conv_embed.embed_dim * 2 for layer in self.layers: num_params += layer.get_num_params() return num_params def prune(self): new_config = defaultdict(list) for layer in self.layers: attention_config = layer.attention.prune() new_config["use_attention"].append(attention_config["use_attention"]) if "remaining_heads" in attention_config: new_config["remaining_heads"].append(attention_config["remaining_heads"]) else: new_config["num_heads"].append(attention_config["num_heads"]) if not attention_config["use_attention"]: layer.attention = None ff_config = layer.feed_forward.prune() new_config["use_feed_forward"].append(ff_config["use_feed_forward"]) new_config["ff_interm_features"].append(ff_config["ff_interm_features"]) if not ff_config["use_feed_forward"]: layer.feed_forward = None return new_config class Encoder(Module): def __init__( self, feature_projection: Module, transformer: Module, ): super().__init__() self.feature_projection = feature_projection self.transformer = transformer def _preprocess( self, features: Tensor, lengths: Optional[Tensor] = None, ) -> Tuple[Tensor, Optional[Tensor]]: x = self.feature_projection(features) mask: Optional[Tensor] = None if lengths is not None: batch_size, max_len, _ = x.shape # create mask for padded elements and zero-out them mask = torch.arange(max_len, device=lengths.device).expand(batch_size, max_len) >= lengths[:, None] x[mask] = 0.0 # extend the mask to attention shape and set weight mask = -10000.0 * mask[:, None, None, :].to(dtype=features.dtype) mask = mask.expand(batch_size, 1, max_len, max_len) return x, mask def forward( self, features: Tensor, lengths: Optional[Tensor] = None, ) -> Tensor: x, mask = self._preprocess(features, lengths) x = self.transformer(x, attention_mask=mask) return x def extract_features( self, features: Tensor, lengths: Optional[Tensor] = None, num_layers: Optional[int] = None, ) -> List[Tensor]: x, masks = self._preprocess(features, lengths) interm = self.transformer.get_intermediate_outputs(x, attention_mask=masks, num_layers=num_layers) return [x] + interm def get_num_params(self, in_features): """Calculate the current model size.""" feature_projection_size = self.feature_projection.get_num_params(in_features) transformer_size = self.transformer.get_num_params() return feature_projection_size + transformer_size def prune(self, conv_out_index): """In-place pruning of submodules.""" prune_layer_norm(self.feature_projection.layer_norm, conv_out_index) prune_linear_layer(self.feature_projection.projection, conv_out_index, "input") transformer_config = self.transformer.prune() return transformer_config ################################################################################ def _get_feature_extractor( norm_mode: str, shapes: List[Tuple[int, int, int]], bias: bool, prune_conv_channels: bool = False, ) -> FeatureExtractor: """ Args: norm_mode (str): Either "group_norm" or "layer_norm". If "group_norm", then a single normalization is applied in the first convolution block. Otherwise, all the convolution blocks will have layer normalization. This option corresponds to "extractor_mode" from fairseq. Expected values are "group_norm" for Base arch, and "layer_norm" for Large arch. shapes (list of tuple of int): Configuration of convolution layers. List of convolution configuration, i.e. ``[(output_channel, kernel_size, stride), ...]`` This option corresponds to "conv_feature_layers" from fairseq. Expected values are ``[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2`` for all the architectures. bias (bool): Whether to include bias term to each convolution operation. This option corresponds to "conv_bias" from fairseq. Expected values are False for Base arch, and True for Large arch. See Also: * Original implementation https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L666-L733 * "extractor_mode" - Def and base: https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L38-L45 - Large: https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L52 * "conv_feature_layers" - Def, base and large: https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L94-L100 * "conv_bias" - Def and base: https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L101-L103 - Large: https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L61 """ if norm_mode not in ["group_norm", "layer_norm"]: raise ValueError("Invalid norm mode") blocks = [] in_channels = 1 for i, (out_channels, kernel_size, stride) in enumerate(shapes): normalization = None if norm_mode == "group_norm" and i == 0: normalization = nn.GroupNorm( num_groups=out_channels, num_channels=out_channels, affine=True, ) elif norm_mode == "layer_norm": normalization = LayerNorm( normalized_shape=out_channels, elementwise_affine=True, ) blocks.append( ConvLayerBlock( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, bias=bias, layer_norm=normalization, prune_conv_channels=prune_conv_channels, ) ) in_channels = out_channels return FeatureExtractor(nn.ModuleList(blocks)) def _get_encoder( in_features: int, embed_dim: int, dropout_input: float, pos_conv_kernel: int, pos_conv_groups: int, num_layers: int, use_attention: List[bool], use_feed_forward: List[bool], num_heads: List[int], head_dim: int, attention_dropout: float, ff_interm_features: List[int], ff_interm_dropout: float, dropout: float, layer_norm_first: bool, layer_drop: float, prune_attention_heads: bool = False, prune_attention_layer: bool = False, prune_feed_forward_intermediate: bool = False, prune_feed_forward_layer: bool = False, ) -> Encoder: """ Args: in_features (int): The number of input features. embed_dim (int): The dimension of embedding. This option corresponds to "encoder_embed_dim" from fairseq. Expected values are 768 for Base arch, and 1024 for Large arch. dropout_input (float): The dropout probability applied after the input feature is projected to ``embed_dim``. This option corresponds to "dropout_input" from fairseq. Expected values are 0.1 for both Base and Large arch. pos_conv_kernel (int): The kernel size of convolutional positional embeddings. This option corresponds to "conv_pos" from fairseq. Expected values are 128 for both Base and Large arch. pos_conv_groups (int): The number of groups of convolutional positional embeddings. This option corresponds to "conv_pos_groups" from fairseq. Expected values are 16 for both Base and Large arch. num_layers (int): The number of self attention layers in transformer block. This option corresponds to "encoder_layers" from fairseq. Expected values are 12 for Base and 24 for Large arch. num_heads (int): The number of heads in self attention layers. This option corresponds to "encoder_attention_heads" from fairseq. Expected values are 12 for Base and 16 for Large arch. attention_dropout (float): The dropout probability applied after softmax in self-attention layer. This option corresponds to "attention_dropout" from fairseq. Expected values are 0.1 for Base and 0.0 for Large arch. ff_interm_features (int): The dimension of hidden features in feed forward layer. This option corresponds to "encoder_ffn_embed_dim" from fairseq. Expected values are 3072 for Base and 4096 for Large arch. ff_interm_dropout (float): The dropout probability applied in feedforward layer. This option correspinds to "activation_dropout" from fairseq. Expected values are 0.1 for both Base and Large arch. dropout (float): The dropout probability applied at the end of feed forward layer. This option corresponds to "dropout" from fairseq. Expected values are 0.1 for Base and 0.0 for Large arch. layer_norm_first (bool): Control the order of layer norm in transformer layer and each encoder layer. If True, in transformer layer, layer norm is applied before features are fed to encoder layers. In encoder layer, two layer norms are applied before and after self attention. If False, in transformer layer, layer norm is applied after features are fed to encoder layers. In encoder layer, two layer norms are applied after self attention, before and after feed forward. This option corresponds to "layer_norm_first" from fairseq. Expected values are False for Base and True for Large arch. layer_drop (float): Probability to drop each encoder layer during training. This option corresponds to "layerdrop" from fairseq. Expected values are 0.1 for both Base and Large arch. See Also: * "encoder_embed_dim" - Def and base https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L49-L51 - Large https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L64 * "dropout_input" - Def, base and large https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L75-L78 * "conv_pos" - Def, base and large NOTE: The description is wrong. https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L204-L207 - Usage https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L756 * "conv_pos_groups" - Def, base and large https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L208-L211 * "encoder_layers" - Def and base https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L46-L48 - Large https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L63 * "encoder_attention_heads" - Def and base https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L55-L57 - Large https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L66 * "attention_dropout" - Def and base https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L66-L68 - Large https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L60 * "encoder_ffn_embed_dim" - Def and base https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L52-L54 - Large https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L65 * "activation_dropout" - Def https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L69-L71 - Base https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/base_960h.yaml#L55 - Large https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/vox_960h.yaml#L55 * "dropout" - Def and base https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L63-L65 - Large https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L59 * "layer_norm_first" - Def and base https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L91-L93 - Large https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L53 * "layerdrop" - Def https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L72-L74 - Base https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/base_960h.yaml#L54 - Large https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/vox_960h.yaml#L54 """ feature_projection = FeatureProjection(in_features, embed_dim, dropout_input) pos_conv = ConvolutionalPositionalEmbedding(embed_dim, pos_conv_kernel, pos_conv_groups) # Original impl # https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L768-L782 encoder_layers = nn.ModuleList() for idx in range(num_layers): if use_attention[idx]: attention = SelfAttention( embed_dim=embed_dim, num_heads=num_heads[idx], head_dim=head_dim, dropout=attention_dropout, prune_heads=prune_attention_heads, prune_layer=prune_attention_layer, ) else: attention = None if use_feed_forward[idx]: feed_forward = FeedForward( io_features=embed_dim, intermediate_features=ff_interm_features[idx], intermediate_dropout=ff_interm_dropout, output_dropout=dropout, prune_intermediate=prune_feed_forward_intermediate, prune_layer=prune_feed_forward_layer, ) else: feed_forward = None encoder_layers.append( EncoderLayer( attention=attention, dropout=dropout, layer_norm_first=layer_norm_first, feed_forward=feed_forward, embed_dim=embed_dim, ) ) transformer = Transformer( pos_conv_embed=pos_conv, dropout=dropout, layers=encoder_layers, layer_norm_first=not layer_norm_first, layer_drop=layer_drop, ) return Encoder(feature_projection, transformer) def _get_wavlm_encoder( in_features: int, embed_dim: int, dropout_input: float, pos_conv_kernel: int, pos_conv_groups: int, num_layers: int, use_attention: List[bool], use_feed_forward: List[bool], total_num_heads: List[int], remaining_heads: List[List[int]], num_buckets: int, max_distance: int, attention_dropout: float, ff_interm_features: List[int], ff_interm_dropout: float, dropout: float, layer_norm_first: bool, layer_drop: float, prune_attention_heads: bool = False, prune_attention_layer: bool = False, prune_feed_forward_intermediate: bool = False, prune_feed_forward_layer: bool = False, ) -> Encoder: """ Construct encoder for WavLM model :cite:`chen2022wavlm`. The structure of the encoder and most of the argments are the same as in :py:func:`_get_encoder` so refer there for documentation. The only difference from Wav2Vec2 encoder is usage of `WavLMSelfAttention` instead of `SelfAttention` and two additional parameters: `num_buckets` and `max_distance`. Args: in_features (int): See :py:func:`_get_encoder`. embed_dim (int): See :py:func:`_get_encoder`. dropout_input (float): See :py:func:`_get_encoder`. pos_conv_kernel (int): See :py:func:`_get_encoder`. pos_conv_groups (int): See :py:func:`_get_encoder`. num_layers (int): See :py:func:`_get_encoder`. num_heads (int): See :py:func:`_get_encoder`. num_buckets (int): Number of buckets for relative position embedding. max_distance (int): Maximum distance for relative position embedding. attention_dropout (float): See :py:func:`_get_encoder`. ff_interm_features (int): See :py:func:`_get_encoder`. ff_interm_dropout (float): See :py:func:`_get_encoder`. dropout (float): See :py:func:`_get_encoder`. layer_norm_first (bool): See :py:func:`_get_encoder`. layer_drop (float): See :py:func:`_get_encoder`. """ feature_projection = FeatureProjection(in_features, embed_dim, dropout_input) pos_conv = ConvolutionalPositionalEmbedding(embed_dim, pos_conv_kernel, pos_conv_groups) # Original impl # https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L768-L782 encoder_layers = nn.ModuleList() for i in range(num_layers): if use_attention[i]: attention = WavLMSelfAttention( embed_dim=embed_dim, total_num_heads=total_num_heads[i], remaining_heads=remaining_heads[i], dropout=attention_dropout, has_relative_attention_bias=(i == 0), # Position embedding is only necessary in the first layer. num_buckets=num_buckets, max_distance=max_distance, prune_heads=prune_attention_heads, prune_layer=prune_attention_layer, ) else: attention = None if use_feed_forward[i]: feed_forward = FeedForward( io_features=embed_dim, intermediate_features=ff_interm_features[i], intermediate_dropout=ff_interm_dropout, output_dropout=dropout, prune_intermediate=prune_feed_forward_intermediate, prune_layer=prune_feed_forward_layer, ) else: feed_forward = None encoder_layers.append( EncoderLayer( attention=attention, dropout=dropout, layer_norm_first=layer_norm_first, feed_forward=feed_forward, embed_dim=embed_dim, ) ) transformer = Transformer( pos_conv_embed=pos_conv, dropout=dropout, layers=encoder_layers, layer_norm_first=not layer_norm_first, layer_drop=layer_drop, ) return Encoder(feature_projection, transformer) def _get_padding_mask(input: Tensor, lengths: Tensor) -> Tensor: """Generate the padding mask given the padded input and the lengths Tensors. Args: input (Tensor): The padded Tensor of dimension `[batch, max_len, frequency]`. lengths (Tensor): The lengths Tensor of dimension `[batch,]`. Returns: (Tensor): The padding mask. """ batch_size, max_len, _ = input.shape mask = torch.arange(max_len, device=lengths.device).expand(batch_size, max_len) >= lengths[:, None] return mask class GradMultiply(torch.autograd.Function): @staticmethod def forward(ctx, x, scale): ctx.scale = scale res = x.new(x) return res @staticmethod def backward(ctx, grad): return grad * ctx.scale, None