yl12053's picture
FC
b181bc5
raw
history blame
58.7 kB
"""Building blocks for speech SSL models supporting pruning.
Originally from:
https://github.com/pytorch/audio/blob/main/torchaudio/models/wav2vec2/components.py
"""
from collections import defaultdict
from typing import List, Optional, Tuple
import math
import torch
from torch import nn, Tensor
from torch.nn import Module, Parameter
from .hardconcrete import HardConcrete
from .pruning_utils import (
prune_linear_layer,
prune_conv1d_layer,
prune_layer_norm,
)
def _init_transformer_params(module):
    """
    Initialize the weights of a Transformer module in Wav2Vec2/HuBERT.

    * ``nn.Linear``: the weight is drawn from a normal distribution with mean 0 and
      standard deviation 0.02; the bias, if present, is set to 0.
    * ``nn.Embedding``: the weight is drawn from the same normal distribution; if
      ``padding_idx`` is not None, the row at ``padding_idx`` is set to 0.
    * Any other module is left untouched.

    Note:
        This method corresponds to
        `init_bert_params
        <https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/transformer_sentence_encoder.py#L21>`__
        in the original ``fairseq`` implementation.
    """

    def _fill_normal(tensor):
        # Sample on a CPU view/copy and copy the result back to the tensor's device;
        # presumably so the draw uses the CPU RNG wherever the module lives.
        on_cpu = tensor.cpu()
        on_cpu.normal_(mean=0.0, std=0.02)
        tensor.copy_(on_cpu.to(tensor.device))

    if isinstance(module, nn.Linear):
        _fill_normal(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()

    if isinstance(module, nn.Embedding):
        _fill_normal(module.weight.data)
        pad = module.padding_idx
        if pad is not None:
            module.weight.data[pad].zero_()
class LayerNorm(nn.LayerNorm):
    """Layer norm with transpose.

    The last two axes of the input are swapped, the result is normalized with this
    module's ``normalized_shape``, ``weight``, ``bias`` and ``eps``, and the axes are
    swapped back. For a 1-D ``normalized_shape`` this normalizes a channel-first
    ``[..., channels, frames]`` input over its channel axis.
    """

    def forward(self, input: Tensor) -> Tensor:
        swapped = input.transpose(-2, -1)
        normed = nn.functional.layer_norm(
            swapped, self.normalized_shape, self.weight, self.bias, self.eps
        )
        return normed.transpose(-2, -1)
class ConvLayerBlock(Module):
    """Convolution unit of FeatureExtractor.

    Computes ``Conv1d -> (optional) layer_norm -> GELU``. When channel pruning is
    enabled, the result is additionally multiplied by a ``HardConcrete`` mask over the
    output channels.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int): Convolution kernel size.
        stride (int): Convolution stride.
        bias (bool): Whether the convolution has a bias term.
        layer_norm (Module or None): Normalization applied right after the convolution.
        prune_conv_channels (bool, optional): If ``True``, attach a ``HardConcrete``
            mask over the output channels. (Default: ``False``)
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        bias: bool,
        layer_norm: Optional[Module],
        prune_conv_channels: bool = False,
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.layer_norm = layer_norm
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            bias=bias,
        )
        self.hard_concrete = (
            HardConcrete(n_in=out_channels, init_mean=0.01) if prune_conv_channels else None
        )

    def forward(
        self,
        x: Tensor,
        length: Optional[Tensor],
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Args:
            x (Tensor): Shape: ``[batch, in_channels, in_frame]``.
            length (Tensor or None, optional): Shape ``[batch, ]``.
        Returns:
            Tensor: Shape ``[batch, out_channels, out_frames]``.
            Optional[Tensor]: Shape ``[batch, ]``; ``None`` when ``length`` is ``None``.
        """
        out = self.conv(x)
        if self.layer_norm is not None:
            out = self.layer_norm(out)
        out = nn.functional.gelu(out)
        if self.hard_concrete is not None:
            # (out_channels,) mask, broadcast over the frame axis.
            out = out * self.hard_concrete().unsqueeze(-1)

        if length is None:
            return out, None

        # Valid output frames of an unpadded convolution: floor((L - k) / s) + 1.
        new_length = torch.div(length - self.kernel_size, self.stride, rounding_mode="floor") + 1
        # Very short inputs (e.g. length 0) give a negative value; clamp it at zero.
        new_length = torch.max(torch.zeros_like(new_length), new_length)
        return out, new_length

    def get_num_params_and_out_channels(self, in_channels):
        """Return ``(num_params, out_channels)`` for this block.

        ``in_channels`` may be a number or a tensor (e.g. the previous block's
        ``out_channels``). With a ``HardConcrete`` mask, ``out_channels`` is its
        ``l0_norm()`` (presumably the expected number of active channels) instead of
        the convolution's nominal output size.
        """
        if self.hard_concrete is None:
            out_channels = self.conv.out_channels
        else:
            out_channels = self.hard_concrete.l0_norm()

        total = in_channels * out_channels * self.kernel_size  # conv weight
        if self.conv.bias is not None:
            total += out_channels  # conv bias
        if self.layer_norm is not None:
            total += out_channels * 2  # norm weight + bias
        return total, out_channels
class FeatureExtractor(Module):
    """Extract features from audio.

    Args:
        conv_layers (nn.ModuleList):
            Convolution layers, applied in order. Each layer is called as
            ``layer(x, length)`` and is expected to expose ``conv`` (an ``nn.Conv1d``),
            ``kernel_size``, ``stride``, ``layer_norm``, ``hard_concrete`` and
            ``get_num_params_and_out_channels`` (as ``ConvLayerBlock`` does).
    """

    def __init__(
        self,
        conv_layers: nn.ModuleList,
    ):
        super().__init__()
        self.conv_layers = conv_layers
        # NOTE: a dummy weight used to save the soft mask of the last conv layer.
        # It is frozen (requires_grad=False) and multiplied onto the output features;
        # ``prune`` folds the last layer's mask into it and shrinks it to the kept channels.
        self.dummy_weight = nn.Parameter(
            torch.ones(conv_layers[-1].conv.out_channels, dtype=torch.float32),
            requires_grad=False,
        )

    def forward(
        self,
        x: Tensor,
        length: Optional[Tensor],
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Args:
            x (Tensor):
                Input Tensor representing a batch of audio,
                shape: ``[batch, time]``.
            length (Tensor or None, optional):
                Valid length of each input sample. shape: ``[batch, ]``.
        Returns:
            Tensor:
                The resulting feature, shape: ``[batch, frame, feature]``
            Optional[Tensor]:
                Valid length of each output sample. shape: ``[batch, ]``.
        Raises:
            ValueError: If ``x`` is not 2D.
        """
        if x.ndim != 2:
            raise ValueError(f"Expected the input Tensor to be 2D (batch, time), but received {list(x.shape)}")

        x = x.unsqueeze(1)  # (batch, channel==1, frame)
        for layer in self.conv_layers:
            x, length = layer(x, length)  # (batch, feature, frame)
        x = x.transpose(1, 2)  # (batch, frame, feature)
        x = x * self.dummy_weight
        return x, length

    def get_num_params_and_final_out_channels(self):
        """Return ``(num_params, final_out_channels)``.

        Sums each layer's ``get_num_params_and_out_channels``, feeding each layer's
        output channel count into the next, and also counts ``dummy_weight``.
        """
        in_channels = 1
        num_params = 0
        for layer in self.conv_layers:
            layer_params, in_channels = layer.get_num_params_and_out_channels(in_channels)
            num_params += layer_params

        num_params += in_channels  # dummy weight
        return num_params, in_channels

    def prune(self):
        """Prune conv layers and the dummy weight based on HardConcrete parameters.

        This is an in-place operation. Every HardConcrete module must be in eval mode
        (asserted). For each layer with a mask, the output channels whose mask is zero
        are removed from the conv (and its ``layer_norm``). The mask values are folded
        into the next layer's input weights, or into ``dummy_weight`` for the last
        layer, before the matching input channels are removed there.

        Returns:
            (list of tuple, Tensor):
                The new conv configuration ``[(output_channel, kernel_size, stride), ...]``
                and the indices of the output channels kept in the last conv layer.
        """
        new_config = []  # [(output_channel, kernel_size, stride), ...]
        for idx, layer in enumerate(self.conv_layers):
            if layer.hard_concrete is not None:
                assert not layer.hard_concrete.training
                mask = layer.hard_concrete()  # (out_features,)
                index = mask.nonzero().squeeze(-1)  # 2D -> 1D
                assert len(index) > 0, f"Conv channels pruned to zero at index {idx}"
                new_config.append(
                    (len(index), layer.kernel_size, layer.stride)
                )

                # prune the current layer
                prune_conv1d_layer(layer.conv, index, "output")
                if layer.layer_norm is not None:
                    prune_layer_norm(layer.layer_norm, index)

                # prune the next layer (the dummy weight plays that role for the last one)
                if idx == len(self.conv_layers) - 1:
                    self.dummy_weight.data *= mask
                    self.dummy_weight = nn.Parameter(
                        self.dummy_weight.index_select(0, index).clone().detach(), requires_grad=False
                    )
                else:
                    # weight is (out, in, kernel); scale each input channel by its mask value
                    self.conv_layers[idx + 1].conv.weight.data *= mask.unsqueeze(-1)
                    prune_conv1d_layer(self.conv_layers[idx + 1].conv, index, dim="input")

                layer.hard_concrete = None
            else:
                new_config.append(
                    (layer.conv.out_channels, layer.kernel_size, layer.stride)
                )
                index = torch.arange(layer.conv.out_channels, dtype=torch.long)

        return new_config, index
class FeatureProjection(Module):
    """Layer that connects FeatureExtractor and Encoder.

    Applies ``LayerNorm -> Linear -> Dropout`` to project the extracted features to
    the encoder dimension.

    Args:
        in_features (int): Input feature dim.
        out_features (int): Output feature dim.
        dropout (float): Dropout probability.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        dropout: float,
    ):
        super().__init__()
        self.layer_norm = nn.LayerNorm(in_features)
        self.projection = nn.Linear(in_features, out_features)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x (Tensor):
                Feature Tensor. shape: ``[batch, frame, in_feature]``
        Returns:
            Tensor: Projected features. ``[batch, frame, out_feature]``.
        """
        normalized = self.layer_norm(x)
        projected = self.projection(normalized)
        return self.dropout(projected)

    def get_num_params(self, in_features):
        """Parameter count for ``in_features`` inputs (which may differ from the built size).

        Counts the LayerNorm weight and bias plus the projection weight and bias.
        """
        norm_params = in_features * 2
        projection_params = (in_features + 1) * self.projection.out_features
        return norm_params + projection_params
class ConvolutionalPositionalEmbedding(Module):
    """Positional embedding which is placed at the beginning of Transformer.

    A weight-normalized, grouped ``Conv1d`` over the frame axis followed by GELU.

    Args:
        embed_dim (int): Feature dimension of the input Tensor.
        kernel_size (int): The number of frames to be used.
        groups (int): The number of groups in feature dimensions.
    """

    def __init__(
        self,
        embed_dim: int,
        kernel_size: int,
        groups: int,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kernel_size = kernel_size
        self.conv = nn.Conv1d(
            in_channels=embed_dim,
            out_channels=embed_dim,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=groups,
        )
        self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
        # Padding by kernel_size // 2 on both sides yields one extra output frame for an
        # even kernel; ``forward`` trims it so the frame count is unchanged.
        self.num_remove: int = 1 if kernel_size % 2 == 0 else 0

    def __prepare_scriptable__(self):
        """Remove weight normalization from ``self.conv`` before scripting; returns ``self``."""
        # Iterate over a snapshot: ``remove_weight_norm`` deletes the hook from
        # ``_forward_pre_hooks``, and mutating that dict while iterating over it can raise
        # ``RuntimeError`` (e.g. when another hook is registered after the weight-norm one).
        for hook in list(self.conv._forward_pre_hooks.values()):
            # The hook we want to remove is an instance of WeightNorm class, so
            # normally we would do `if isinstance(...)` but this class is not accessible
            # because of shadowing, so we check the module name directly.
            # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
            if hook.__module__ == "torch.nn.utils.weight_norm" and hook.__class__.__name__ == "WeightNorm":
                torch.nn.utils.remove_weight_norm(self.conv)
                break
        return self

    def forward(self, x):
        """
        Args:
            x (Tensor): shape ``[batch, frame, feature]``.
        Returns:
            Tensor: The resulting feature. Shape ``[batch, frame, feature]``.
        """
        x = x.transpose(-2, -1)
        x = self.conv(x)
        if self.num_remove > 0:
            x = x[..., : -self.num_remove]
        x = torch.nn.functional.gelu(x)
        x = x.transpose(-2, -1)
        return x
class SelfAttention(Module):
    """Multihead Self Attention module.

    Args:
        embed_dim (int): Total dimension of the model.
        num_heads (int): The number of heads.
        head_dim (int): Dimension of each head; the q/k/v projections produce
            ``num_heads * head_dim`` features.
        dropout (float, optional):
            Dropout probability on attn_output_weights. Default: ``0.0``
        prune_heads (bool, optional): If ``True``, attach a ``HardConcrete`` mask over the
            heads. Default: ``False``
        prune_layer (bool, optional): If ``True``, attach a single-element ``HardConcrete``
            mask that scales the whole attention output. Default: ``False``
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        head_dim: int,
        dropout: float = 0.0,
        prune_heads: bool = False,  # whether to prune attention heads
        prune_layer: bool = False,  # whether to prune entire attention layers
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.dropout = torch.nn.Dropout(dropout)
        self.scaling = self.head_dim**-0.5

        inner_dim = num_heads * head_dim
        self.k_proj = nn.Linear(embed_dim, inner_dim, bias=True)
        self.v_proj = nn.Linear(embed_dim, inner_dim, bias=True)
        self.q_proj = nn.Linear(embed_dim, inner_dim, bias=True)
        self.out_proj = nn.Linear(inner_dim, embed_dim, bias=True)

        self.hard_concrete_for_heads = HardConcrete(n_in=num_heads, init_mean=0.01) if prune_heads else None
        self.hard_concrete_for_layer = HardConcrete(n_in=1, init_mean=0.01) if prune_layer else None

    def forward(
        self,
        x: Tensor,
        attention_mask: Optional[Tensor] = None,
        position_bias: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Args:
            x (Tensor): shape: ``[batch_size, sequence_length, embed_dim]``.
            attention_mask (Tensor or ``None``, optional): Additive mask added to the
                attention scores; shape: ``[batch_size, 1, sequence_length, sequence_length]``
            position_bias: Not used. Only for the compatibility with :py:class:`WavLMSelfAttention`.
            key_padding_mask (Tensor or ``None``): Not used. Only for the compatibility with
                :py:class:`WavLMSelfAttention`.
        Returns:
            (Tensor, ``None``): The attention output, of shape
            ``[batch, sequence_length, embed_dim]``, and ``None`` (necessary for
            compatibility with :py:class:`WavLMSelfAttention`).
        Raises:
            ValueError: If ``x`` is not 3D or its last dimension is not ``embed_dim``.
        """
        if x.ndim != 3 or x.shape[2] != self.embed_dim:
            raise ValueError(
                f"The expected input shape is (batch, sequence, embed_dim=={self.embed_dim}). " f"Found {x.shape}."
            )
        batch_size, length, _ = x.size()
        per_head = (batch_size, length, self.num_heads, self.head_dim)

        query = self.q_proj(x).view(*per_head).transpose(2, 1)  # B, nH, L, Hd
        key_t = self.k_proj(x).view(*per_head).permute(0, 2, 3, 1)  # B, nH, Hd, L
        value = self.v_proj(x).view(*per_head).transpose(2, 1)  # B, nH, L, Hd

        # Scale q before the product to avoid value overflow.
        scores = (self.scaling * query) @ key_t  # B, nH, L, L
        if attention_mask is not None:
            scores += attention_mask

        # Softmax is invariant to subtracting a per-row constant; removing the row max
        # avoids overflow in softmax (see Equation 7 in https://arxiv.org/abs/2112.08778).
        scores = scores - scores.max(dim=-1, keepdim=True)[0]
        probs = self.dropout(torch.nn.functional.softmax(scores, dim=-1))

        context = probs @ value  # B, nH, L, Hd
        if self.hard_concrete_for_heads is not None:
            # (nH,) mask broadcast over the sequence and head-dim axes
            context = context * self.hard_concrete_for_heads().unsqueeze(-1).unsqueeze(-1)
        context = context.transpose(2, 1).reshape(batch_size, length, self.num_heads * self.head_dim)

        output = self.out_proj(context)
        if self.hard_concrete_for_layer is not None:
            output = output * self.hard_concrete_for_layer()  # (1,) mask
        return output, None  # Necessary for compatibility with WavLMSelfAttention

    def get_num_params(self):
        """Parameter count of the q/k/v/out projections (weights and biases).

        With head pruning, the head count is the head mask's ``l0_norm()``; with layer
        pruning, the total is scaled by the layer mask's ``l0_norm()``.
        """
        if self.hard_concrete_for_heads is None:
            heads = self.num_heads
        else:
            heads = self.hard_concrete_for_heads.l0_norm()

        qkv_params = (self.embed_dim + 1) * heads * self.head_dim * 3
        out_params = (heads * self.head_dim + 1) * self.embed_dim
        total = qkv_params + out_params
        if self.hard_concrete_for_layer is not None:
            total *= self.hard_concrete_for_layer.l0_norm()
        return total

    def prune(self):
        """Apply the (eval-mode) HardConcrete masks in place and return the new config.

        Returns:
            dict: ``{"use_attention": bool, "num_heads": int}``.
        """
        config = {
            "use_attention": True,
            "num_heads": self.num_heads,
        }

        layer_hc = self.hard_concrete_for_layer
        if layer_hc is not None:
            assert not layer_hc.training
            layer_mask = layer_hc()  # (1,)
            self.out_proj.weight.data *= layer_mask
            self.out_proj.bias.data *= layer_mask
            if layer_mask == 0:
                config["use_attention"] = False
            self.hard_concrete_for_layer = None

        heads_hc = self.hard_concrete_for_heads
        if heads_hc is not None:
            assert not heads_hc.training
            head_mask = heads_hc()  # (num_heads,)
            config["num_heads"] = len(head_mask.nonzero())
            if config["num_heads"] == 0:
                config["use_attention"] = False
            else:
                # Expand the per-head mask to one entry per projected feature.
                full_mask = head_mask.repeat_interleave(self.head_dim)
                full_index = full_mask.nonzero().squeeze(-1)  # 1D
                for proj in (self.k_proj, self.v_proj, self.q_proj):
                    prune_linear_layer(proj, full_index, "output")
                self.out_proj.weight.data *= full_mask
                prune_linear_layer(self.out_proj, full_index, "input")
            self.hard_concrete_for_heads = None

        return config
class WavLMSelfAttention(SelfAttention):
    """Multi-headed self-attention for WavLM model :cite:`chen2022wavlm`.

    Supports head pruning: the q/k/v/out projections only hold the heads listed in
    ``remaining_heads``. The relative position embedding (and its gating) is computed
    for all ``total_num_heads`` heads and then restricted to the remaining ones.

    Args:
        embed_dim (int): Total dimension of the model.
        total_num_heads (int): The number of heads before any pruning. The per-head
            dimension is ``embed_dim // total_num_heads``.
        remaining_heads (list of int or None, optional): Indices into
            ``range(total_num_heads)`` of the heads kept in this layer. ``None`` keeps all
            heads. (Default: ``None``)
        dropout (float, optional): Dropout probability on attn_output_weights. (Default: ``0.0``)
        bias (bool, optional): If ``True``, add bias to input / output projection layers. (Default: ``True``)
        has_relative_attention_bias (bool, optional): If ``True``, apply relative position embedding.
            Necessary in the first encoder layer, but not in the subsequent ones. (Default: ``False``)
        num_buckets (int, optional): Number of buckets for relative position embedding. (Default: ``32``)
        max_distance (int, optional): Maximum distance for relative position embedding. (Default: ``128``)
        gru_rel_pos (bool, optional): If ``True``, apply gated relative position embedding. (Default: ``True``)
        prune_heads (bool, optional): If ``True``, attach a ``HardConcrete`` mask over the
            remaining heads. (Default: ``False``)
        prune_layer (bool, optional): If ``True``, attach a ``HardConcrete`` mask over the
            whole attention output. (Default: ``False``)
    """

    def __init__(
        self,
        embed_dim: int,
        total_num_heads: int,
        remaining_heads: Optional[List[int]] = None,
        dropout: float = 0.0,
        bias: bool = True,
        has_relative_attention_bias: bool = False,
        num_buckets: int = 32,
        max_distance: int = 128,
        gru_rel_pos: bool = True,
        prune_heads: bool = False,
        prune_layer: bool = False,
    ):
        # Set before ``super().__init__`` because the head count handed to the base
        # class is derived from ``remaining_heads``.
        self.total_num_heads = total_num_heads
        if remaining_heads is None:
            self.remaining_heads = list(range(total_num_heads))
        else:
            self.remaining_heads = remaining_heads  # list of indices
        self.head_dim = embed_dim // total_num_heads

        super().__init__(embed_dim, len(self.remaining_heads), self.head_dim, dropout, prune_heads, prune_layer)

        self.has_relative_attention_bias = has_relative_attention_bias
        self.num_buckets = num_buckets
        self.max_distance = max_distance

        if has_relative_attention_bias:
            self.rel_attn_embed = nn.Embedding(num_buckets, total_num_heads)
        else:
            self.rel_attn_embed = None

        # override linear layers to customize bias
        self.k_proj = nn.Linear(embed_dim, len(self.remaining_heads) * self.head_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, len(self.remaining_heads) * self.head_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, len(self.remaining_heads) * self.head_dim, bias=bias)
        self.out_proj = nn.Linear(len(self.remaining_heads) * self.head_dim, embed_dim, bias=bias)

        self.gru_rel_pos = gru_rel_pos
        if self.gru_rel_pos:
            self.gru_rel_pos_linear = nn.Linear(self.head_dim, 8)
            self.gru_rel_pos_const = nn.Parameter(torch.ones(1, total_num_heads, 1, 1))
        self.has_position_bias = True

    def compute_bias(self, query_length: int, key_length: int) -> Tensor:
        """Compute relative position embeddings for WavLM model.

        Requires ``rel_attn_embed`` (i.e. ``has_relative_attention_bias=True``).

        Args:
            query_length (int): Query position can take values between 0 and ``query_length - 1``.
            key_length (int): Key position can take values between 0 and ``key_length - 1``.
        Returns:
            Tensor of shape ``(total_num_heads, query_length, key_length)``, relative positions embeddings
        """
        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
        relative_position = memory_position - context_position  # Shape (query_length, key_length)
        relative_position_bucket = self._relative_positions_bucket(relative_position, bidirectional=True)
        relative_position_bucket = relative_position_bucket.to(self.rel_attn_embed.weight.device)
        values = self.rel_attn_embed(relative_position_bucket)  # Shape (query_length, key_length, total_num_heads)
        values = values.permute([2, 0, 1])
        return values

    def _relative_positions_bucket(self, relative_positions: Tensor, bidirectional: bool = True):
        """Compute relative position buckets for WavLM model. Computation similar to formula (5) in WavLM
        paper :cite:`chen2022wavlm`.

        Args:
            relative_positions (Tensor): Relative offsets (key position minus query position),
                of shape ``(query_length, key_length)``.
            bidirectional (bool): If ``True``, positive and negative offsets get separate halves of
                the buckets. If ``False``, positive offsets (key after query, i.e. above the
                diagonal) are all mapped to bucket 0. (Default ``True``)
        Returns:
            Tensor of shape ``(query_length, key_length)`` filled with the bucket index of each
            relative position.
        """
        num_buckets = self.num_buckets
        max_distance = self.max_distance
        # Shape (query_length, key_length)
        relative_buckets = torch.zeros_like(relative_positions, dtype=torch.long)

        if bidirectional:
            num_buckets = num_buckets // 2
            relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
            relative_positions = torch.abs(relative_positions)
        else:
            relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))

        # Small offsets get one bucket each; larger ones are binned logarithmically up to
        # max_distance. log(0) for zero offsets is discarded by the torch.where below.
        max_exact = num_buckets // 2
        is_small = relative_positions < max_exact

        relative_position_if_large = max_exact + (
            torch.log(relative_positions.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_positions, relative_position_if_large)
        return relative_buckets

    def forward(
        self,
        query: Tensor,
        attention_mask: Optional[Tensor] = None,
        position_bias: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Args:
            query (Tensor): Input of shape ``(batch_size, src_len, embed_dim)``.
            attention_mask (Tensor or None, optional): Additive attention mask. It is added to the
                (gated) relative position bias when one is present, otherwise used on its own.
                (Default: ``None``)
            position_bias (Tensor or None, optional): Position bias of shape
                ``(batch_size * total_num_heads, src_len, src_len)``. When used inside WavLM model
                encoder, it is generated in the first layer (the one with ``rel_attn_embed``) and then
                passed from each encoder layer to the next one. (Default: ``None``)
            key_padding_mask (Tensor or None, optional): Must be ``None`` (asserted). (Default: ``None``)
        Returns:
            attn_output (Tensor): Attention output of shape ``(batch_size, src_len, embed_dim)``.
            position_bias (Tensor or None): The ungated position bias of shape
                ``(batch_size * total_num_heads, src_len, src_len)``, for the next layer.
        """
        bsz, seq_len, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert key_padding_mask is None

        # only for the first layer
        if self.rel_attn_embed is not None and position_bias is None:
            position_bias = self.compute_bias(seq_len, seq_len)
            position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.total_num_heads, seq_len, seq_len)

        attn_mask_rel_pos: Optional[Tensor] = None
        if position_bias is not None:
            attn_mask_rel_pos = position_bias
            if self.gru_rel_pos:  # Apply gating on relative position bias
                query_layer = query.view(bsz, seq_len, self.total_num_heads, -1)
                query_layer = query_layer.permute(0, 2, 1, 3)

                gate_a, gate_b = torch.sigmoid(
                    self.gru_rel_pos_linear(query_layer).view(bsz, self.total_num_heads, seq_len, 2, 4).sum(-1, keepdim=False)
                ).chunk(2, dim=-1)
                gate_a_1 = gate_a * (gate_b * self.gru_rel_pos_const - 1.0) + 2.0
                attn_mask_rel_pos = gate_a_1.view(bsz * self.total_num_heads, -1, 1) * position_bias

            attn_mask_rel_pos = attn_mask_rel_pos.view((-1, seq_len, seq_len))
            # Keep only the bias of the heads that survive pruning.
            attn_mask_rel_pos = attn_mask_rel_pos.reshape(bsz, self.total_num_heads, seq_len, seq_len)[:, self.remaining_heads, :, :]

        attn_mask = attn_mask_rel_pos
        if attention_mask is not None:
            # Without a position bias (e.g. a later layer whose predecessor was skipped by
            # layer drop) the additive mask is used on its own instead of ``None + mask``.
            attn_mask = attention_mask if attn_mask is None else attn_mask + attention_mask
        # NOTE(review): only reachable when asserts are disabled (see the assert above).
        if key_padding_mask is not None:
            attn_mask = attn_mask.masked_fill(
                key_padding_mask.reshape(bsz, 1, 1, seq_len),
                float("-inf")
            )
        attn_output, _ = super().forward(query, attention_mask=attn_mask)

        return attn_output, position_bias

    def prune(self):
        """Apply the (eval-mode) HardConcrete masks in place and return the new config.

        Returns:
            dict: ``{"use_attention": bool, "remaining_heads": list of int}``, where
            ``remaining_heads`` are indices into ``range(total_num_heads)``.
        """
        new_config = {
            "use_attention": True,
            "remaining_heads": self.remaining_heads,
        }
        if self.hard_concrete_for_layer is not None:
            assert not self.hard_concrete_for_layer.training
            layer_mask = self.hard_concrete_for_layer()  # (1,)
            self.out_proj.weight.data *= layer_mask
            self.out_proj.bias.data *= layer_mask
            if layer_mask == 0:
                new_config["use_attention"] = False
            self.hard_concrete_for_layer = None

        if self.hard_concrete_for_heads is not None:
            assert not self.hard_concrete_for_heads.training
            head_mask = self.hard_concrete_for_heads()  # (num_heads,)
            # The mask is indexed by position within ``self.remaining_heads``; map the kept
            # positions back to global head indices so the config stays valid for a layer
            # that was already pruned before (identity when all heads remain).
            kept_positions = head_mask.nonzero().squeeze(-1).tolist()
            new_config["remaining_heads"] = [self.remaining_heads[i] for i in kept_positions]
            if len(new_config["remaining_heads"]) == 0:
                new_config["use_attention"] = False
            else:
                full_mask = head_mask.repeat_interleave(self.head_dim)
                full_index = full_mask.nonzero().squeeze(-1)  # 1D

                prune_linear_layer(self.k_proj, full_index, "output")
                prune_linear_layer(self.v_proj, full_index, "output")
                prune_linear_layer(self.q_proj, full_index, "output")

                self.out_proj.weight.data *= full_mask
                prune_linear_layer(self.out_proj, full_index, "input")
            self.hard_concrete_for_heads = None

        return new_config
class FeedForward(Module):
    """Layer that follows attention layer in encoder layer.

    Computes ``Linear -> GELU -> Dropout -> (optional intermediate mask) -> Linear ->
    Dropout -> (optional layer mask)``.

    Args:
        io_features (int): Input and output feature dimension.
        intermediate_features (int): Hidden (intermediate) dimension.
        intermediate_dropout (float): Dropout probability after the activation.
        output_dropout (float): Dropout probability after the output projection.
        prune_intermediate (bool, optional): If ``True``, attach a ``HardConcrete`` mask over
            the intermediate features. (Default: ``False``)
        prune_layer (bool, optional): If ``True``, attach a single-element ``HardConcrete``
            mask that scales the whole output. (Default: ``False``)
    """

    def __init__(
        self,
        io_features: int,
        intermediate_features: int,
        intermediate_dropout: float,
        output_dropout: float,
        prune_intermediate: bool = False,
        prune_layer: bool = False,
    ):
        super().__init__()
        self.intermediate_dense = nn.Linear(io_features, intermediate_features)
        self.intermediate_dropout = nn.Dropout(intermediate_dropout)
        self.output_dense = nn.Linear(intermediate_features, io_features)
        self.output_dropout = nn.Dropout(output_dropout)

        self.hard_concrete_for_intermediate = (
            HardConcrete(n_in=intermediate_features, init_mean=0.5) if prune_intermediate else None
        )
        self.hard_concrete_for_layer = HardConcrete(n_in=1, init_mean=0.01) if prune_layer else None

    def forward(self, x):
        """
        Args:
            x (Tensor): shape: `(batch, sequence_length, io_features)`
        Returns:
            x (Tensor): shape: `(batch, sequence_length, io_features)`
        """
        hidden = self.intermediate_dropout(torch.nn.functional.gelu(self.intermediate_dense(x)))
        if self.hard_concrete_for_intermediate is not None:
            hidden = hidden * self.hard_concrete_for_intermediate()  # (intermediate_features,)

        out = self.output_dropout(self.output_dense(hidden))
        if self.hard_concrete_for_layer is not None:
            out = out * self.hard_concrete_for_layer()  # (1,)
        return out

    def get_num_params(self):
        """Parameter count of both dense layers (weights and biases).

        With intermediate pruning, the hidden size is the mask's ``l0_norm()``; with
        layer pruning, the total is scaled by the layer mask's ``l0_norm()``.
        """
        io_features = self.intermediate_dense.in_features
        if self.hard_concrete_for_intermediate is None:
            hidden = self.intermediate_dense.out_features
        else:
            hidden = self.hard_concrete_for_intermediate.l0_norm()

        total = (io_features + 1) * hidden + (hidden + 1) * io_features
        if self.hard_concrete_for_layer is not None:
            total *= self.hard_concrete_for_layer.l0_norm()
        return total

    def prune(self):
        """Apply the (eval-mode) HardConcrete masks in place and return the new config.

        Returns:
            dict: ``{"use_feed_forward": bool, "ff_interm_features": int}``.
        """
        config = {
            "use_feed_forward": True,
            "ff_interm_features": self.intermediate_dense.out_features
        }

        layer_hc = self.hard_concrete_for_layer
        if layer_hc is not None:
            assert not layer_hc.training
            layer_mask = layer_hc()
            self.output_dense.weight.data *= layer_mask
            self.output_dense.bias.data *= layer_mask
            if layer_mask == 0:
                config["use_feed_forward"] = False
            self.hard_concrete_for_layer = None

        interm_hc = self.hard_concrete_for_intermediate
        if interm_hc is not None:
            assert not interm_hc.training
            interm_mask = interm_hc()
            kept = interm_mask.nonzero().squeeze(-1)  # NOTE: must specify dim=-1
            config["ff_interm_features"] = len(kept)
            if config["ff_interm_features"] == 0:
                config["use_feed_forward"] = False
            else:
                prune_linear_layer(self.intermediate_dense, kept, "output")
                # Fold the soft mask values into the next layer before dropping columns.
                self.output_dense.weight.data *= interm_mask
                prune_linear_layer(self.output_dense, kept, "input")
            self.hard_concrete_for_intermediate = None

        return config
class EncoderLayer(Module):
    """A layer unit in encoder. Combines multihead self attention and feed forward.

    Either sub-layer may be ``None`` when it has been pruned away entirely.

    Args:
        attention (Module or None): Self-attention sub-layer.
        dropout (float): Dropout applied to the attention output before the residual add.
        layer_norm_first (bool): Pre-norm if ``True``, post-norm otherwise.
        feed_forward (Module or None): Feed-forward sub-layer.
        embed_dim (int): Feature dimension (size of both layer norms).
    """

    def __init__(
        self,
        attention: Optional[Module],  # can be None if the entire layer is pruned
        dropout: float,
        layer_norm_first: bool,
        feed_forward: Optional[Module],  # can be None if the entire layer is pruned
        embed_dim: int,
    ):
        super().__init__()
        self.attention = attention
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.layer_norm_first = layer_norm_first
        self.feed_forward = feed_forward
        self.final_layer_norm = nn.LayerNorm(embed_dim)
        self.embed_dim = embed_dim

    def forward(
        self,
        x: Tensor,
        attention_mask: Optional[Tensor] = None,
        position_bias: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Args:
            x (Tensor): Input of shape ``(batch, sequence_length, embed_dim)``.
            attention_mask (Tensor or ``None``, optional): attention mask
                of shape ``(batch, 1, sequence_length, sequence_length)``. (Default: ``None``)
            position_bias (Tensor or ``None``, optional): position bias of shape
                ``(batch_size * num_heads, src_len, src_len)``.
                Only necessary for WavLM model, ``None`` otherwise. (Default: ``None``)
            key_padding_mask (Tensor or ``None``, optional): key padding mask of shape ``(batch_size, src_len)``.
                Only used for WavLM model, ignored otherwise. (Default: ``None``)
        Returns:
            (x, position_bias): Shapes are the same as in the input. Position bias is only relevant
            for WavLM model, ``None`` otherwise; it passes through unchanged when attention is pruned.
        """
        if self.attention is not None:
            residual = x
            attn_input = self.layer_norm(x) if self.layer_norm_first else x
            attn_output, position_bias = self.attention(
                attn_input, attention_mask=attention_mask, position_bias=position_bias, key_padding_mask=key_padding_mask
            )
            x = residual + self.dropout(attn_output)

        if self.layer_norm_first:
            if self.feed_forward is not None:
                x = x + self.feed_forward(self.final_layer_norm(x))
            return x, position_bias

        # Post-norm: the layer norms are always applied, even if the sub-layers are pruned.
        x = self.layer_norm(x)
        if self.feed_forward is not None:
            x = x + self.feed_forward(x)
        x = self.final_layer_norm(x)
        return x, position_bias

    def get_num_params(self):
        """Parameter count: both layer norms plus the remaining sub-layers."""
        total = self.embed_dim * 2 * 2  # weight + bias of the two layer norms
        for sub_layer in (self.attention, self.feed_forward):
            if sub_layer is not None:
                total += sub_layer.get_num_params()
        return total
class Transformer(Module):
    """Stack of encoder layers preceded by a convolutional positional embedding.

    Args:
        pos_conv_embed (Module): Positional embedding whose output is added to the input;
            must expose ``embed_dim``.
        dropout (float): Dropout applied after the positional embedding.
        layers (Module): Iterable container of encoder layers.
        layer_norm_first (bool): If ``True``, the layer norm is applied before the layers;
            otherwise it is applied to the final output.
        layer_drop (float): Probability of skipping each layer during training.
    """

    def __init__(
        self,
        pos_conv_embed: Module,
        dropout: float,
        layers: Module,
        layer_norm_first: bool,
        layer_drop: float,
    ):
        super().__init__()
        self.pos_conv_embed = pos_conv_embed
        self.layer_norm = nn.LayerNorm(pos_conv_embed.embed_dim)
        self.layer_norm_first = layer_norm_first
        self.layer_drop = layer_drop
        self.dropout = nn.Dropout(dropout)
        self.layers = layers

    def _preprocess(self, x: Tensor):
        """Add the positional embedding, optionally pre-normalize, then apply dropout."""
        x = x + self.pos_conv_embed(x)
        if self.layer_norm_first:
            x = self.layer_norm(x)
        return self.dropout(x)

    def forward(
        self,
        x: Tensor,
        attention_mask: Optional[Tensor] = None,
        position_bias: Optional[Tensor] = None,
    ) -> Tensor:
        x = self._preprocess(x)
        for layer in self.layers:
            # LayerDrop: in training mode each layer is skipped with probability layer_drop.
            # The random draw only happens in training mode.
            if self.training and torch.rand(1).item() <= self.layer_drop:
                continue
            x, position_bias = layer(x, attention_mask, position_bias=position_bias)

        if not self.layer_norm_first:
            x = self.layer_norm(x)
        return x

    def get_intermediate_outputs(
        self,
        x: Tensor,
        attention_mask: Optional[Tensor] = None,
        num_layers: Optional[int] = None,
        position_bias: Optional[Tensor] = None,
    ) -> List[Tensor]:
        """Return the output of each layer (no layer drop), optionally only the first ``num_layers``."""
        if num_layers is not None and not 0 < num_layers <= len(self.layers):
            raise ValueError(f"`num_layers` must be between [1, {len(self.layers)}]")

        outputs: List[Tensor] = []
        x = self._preprocess(x)
        for layer in self.layers:
            x, position_bias = layer(x, attention_mask, position_bias=position_bias)
            outputs.append(x)
            if num_layers is not None and len(outputs) >= num_layers:
                break
        return outputs

    def get_num_params(self):
        """Parameter count of pos_conv_embed, the layer norm, and every layer."""
        embed_params = sum(p.numel() for p in self.pos_conv_embed.parameters())
        norm_params = self.pos_conv_embed.embed_dim * 2  # layer norm weight + bias
        total = embed_params + norm_params
        for layer in self.layers:
            total += layer.get_num_params()
        return total

    def prune(self):
        """Prune every layer's attention and feed-forward in place; collect per-layer configs.

        Sub-layers whose config reports they are unused are replaced by ``None``.
        """
        config = defaultdict(list)
        for layer in self.layers:
            att_cfg = layer.attention.prune()
            config["use_attention"].append(att_cfg["use_attention"])
            heads_key = "remaining_heads" if "remaining_heads" in att_cfg else "num_heads"
            config[heads_key].append(att_cfg[heads_key])
            if not att_cfg["use_attention"]:
                layer.attention = None

            ff_cfg = layer.feed_forward.prune()
            config["use_feed_forward"].append(ff_cfg["use_feed_forward"])
            config["ff_interm_features"].append(ff_cfg["ff_interm_features"])
            if not ff_cfg["use_feed_forward"]:
                layer.feed_forward = None

        return config
class Encoder(Module):
def __init__(
self,
feature_projection: Module,
transformer: Module,
):
super().__init__()
self.feature_projection = feature_projection
self.transformer = transformer
def _preprocess(
self,
features: Tensor,
lengths: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
x = self.feature_projection(features)
mask: Optional[Tensor] = None
if lengths is not None:
batch_size, max_len, _ = x.shape
# create mask for padded elements and zero-out them
mask = torch.arange(max_len, device=lengths.device).expand(batch_size, max_len) >= lengths[:, None]
x[mask] = 0.0
# extend the mask to attention shape and set weight
mask = -10000.0 * mask[:, None, None, :].to(dtype=features.dtype)
mask = mask.expand(batch_size, 1, max_len, max_len)
return x, mask
def forward(
self,
features: Tensor,
lengths: Optional[Tensor] = None,
) -> Tensor:
x, mask = self._preprocess(features, lengths)
x = self.transformer(x, attention_mask=mask)
return x
def extract_features(
self,
features: Tensor,
lengths: Optional[Tensor] = None,
num_layers: Optional[int] = None,
) -> List[Tensor]:
x, masks = self._preprocess(features, lengths)
interm = self.transformer.get_intermediate_outputs(x, attention_mask=masks, num_layers=num_layers)
return [x] + interm
def get_num_params(self, in_features):
"""Calculate the current model size."""
feature_projection_size = self.feature_projection.get_num_params(in_features)
transformer_size = self.transformer.get_num_params()
return feature_projection_size + transformer_size
def prune(self, conv_out_index):
"""In-place pruning of submodules."""
prune_layer_norm(self.feature_projection.layer_norm, conv_out_index)
prune_linear_layer(self.feature_projection.projection, conv_out_index, "input")
transformer_config = self.transformer.prune()
return transformer_config
################################################################################
def _get_feature_extractor(
    norm_mode: str,
    shapes: List[Tuple[int, int, int]],
    bias: bool,
    prune_conv_channels: bool = False,
) -> FeatureExtractor:
    """Build the convolutional feature extractor.

    Args:
        norm_mode (str):
            Either "group_norm" or "layer_norm".
            If "group_norm", then a single normalization is applied
            in the first convolution block. Otherwise, all the convolution
            blocks will have layer normalization.
            This option corresponds to "extractor_mode" from fairseq.
            Expected values are "group_norm" for Base arch, and
            "layer_norm" for Large arch.
        shapes (list of tuple of int):
            Configuration of convolution layers. List of convolution configuration,
            i.e. ``[(output_channel, kernel_size, stride), ...]``
            This option corresponds to "conv_feature_layers" from fairseq.
            Expected values are
            ``[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2``
            for all the architectures.
        bias (bool):
            Whether to include bias term to each convolution operation.
            This option corresponds to "conv_bias" from fairseq.
            Expected values are False for Base arch, and True for Large arch.
        prune_conv_channels (bool, optional):
            Forwarded unchanged to every ``ConvLayerBlock`` as ``prune_conv_channels``.
            It presumably enables pruning of that block's convolution channels;
            see ``ConvLayerBlock``. (Default: ``False``)

    Returns:
        FeatureExtractor: Wraps one ``ConvLayerBlock`` per entry of ``shapes``.
        The first block reads a single input channel, and every later block
        reads the previous block's output channels.

    Raises:
        ValueError: If ``norm_mode`` is neither "group_norm" nor "layer_norm".

    See Also:
        * Original implementation
          https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L666-L733
        * "extractor_mode"
          - Def and base:
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L38-L45
          - Large:
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L52
        * "conv_feature_layers"
          - Def, base and large:
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L94-L100
        * "conv_bias"
          - Def and base:
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L101-L103
          - Large:
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L61
    """
    if norm_mode not in ("group_norm", "layer_norm"):
        raise ValueError("Invalid norm mode")

    conv_blocks = []
    channels_in = 1  # the first block reads a single input channel
    for position, (channels_out, kernel, step) in enumerate(shapes):
        if norm_mode == "layer_norm":
            norm = LayerNorm(normalized_shape=channels_out, elementwise_affine=True)
        elif position == 0:
            # "group_norm" mode: only the first block is normalized, one group per channel.
            norm = nn.GroupNorm(num_groups=channels_out, num_channels=channels_out, affine=True)
        else:
            norm = None
        conv_blocks.append(
            ConvLayerBlock(
                in_channels=channels_in,
                out_channels=channels_out,
                kernel_size=kernel,
                stride=step,
                bias=bias,
                layer_norm=norm,
                prune_conv_channels=prune_conv_channels,
            )
        )
        channels_in = channels_out
    return FeatureExtractor(nn.ModuleList(conv_blocks))
def _get_encoder(
    in_features: int,
    embed_dim: int,
    dropout_input: float,
    pos_conv_kernel: int,
    pos_conv_groups: int,
    num_layers: int,
    use_attention: List[bool],
    use_feed_forward: List[bool],
    num_heads: List[int],
    head_dim: int,
    attention_dropout: float,
    ff_interm_features: List[int],
    ff_interm_dropout: float,
    dropout: float,
    layer_norm_first: bool,
    layer_drop: float,
    prune_attention_heads: bool = False,
    prune_attention_layer: bool = False,
    prune_feed_forward_intermediate: bool = False,
    prune_feed_forward_layer: bool = False,
) -> Encoder:
    """Build the Wav2Vec2/HuBERT encoder: feature projection followed by a transformer.

    Args:
        in_features (int): The number of input features.
        embed_dim (int):
            The dimension of embedding.
            This option corresponds to "encoder_embed_dim" from fairseq.
            Expected values are 768 for Base arch, and 1024 for Large arch.
        dropout_input (float):
            The dropout probability applied after the input feature is projected
            to ``embed_dim``.
            This option corresponds to "dropout_input" from fairseq.
            Expected values are 0.1 for both Base and Large arch.
        pos_conv_kernel (int):
            The kernel size of convolutional positional embeddings.
            This option corresponds to "conv_pos" from fairseq.
            Expected values are 128 for both Base and Large arch.
        pos_conv_groups (int):
            The number of groups of convolutional positional embeddings.
            This option corresponds to "conv_pos_groups" from fairseq.
            Expected values are 16 for both Base and Large arch.
        num_layers (int):
            The number of self attention layers in transformer block.
            This option corresponds to "encoder_layers" from fairseq.
            Expected values are 12 for Base and 24 for Large arch.
        use_attention (list of bool):
            Per-layer flags. When ``use_attention[i]`` is False, layer ``i`` is
            built with ``attention=None``.
        use_feed_forward (list of bool):
            Per-layer flags. When ``use_feed_forward[i]`` is False, layer ``i`` is
            built with ``feed_forward=None``.
        num_heads (list of int):
            The number of heads in the self attention of each layer.
            ``num_heads[i]`` is used for layer ``i``.
            This option corresponds to "encoder_attention_heads" from fairseq.
            Expected values are 12 for Base and 16 for Large arch.
        head_dim (int):
            Passed to every ``SelfAttention`` as ``head_dim``. It is presumably
            the feature dimension of each attention head; confirm against
            ``SelfAttention``.
        attention_dropout (float):
            The dropout probability applied after softmax in self-attention layer.
            This option corresponds to "attention_dropout" from fairseq.
            Expected values are 0.1 for Base and 0.0 for Large arch.
        ff_interm_features (list of int):
            The dimension of hidden features in the feed forward module of each
            layer. ``ff_interm_features[i]`` is used for layer ``i``.
            This option corresponds to "encoder_ffn_embed_dim" from fairseq.
            Expected values are 3072 for Base and 4096 for Large arch.
        ff_interm_dropout (float):
            The dropout probability applied in feedforward layer.
            This option corresponds to "activation_dropout" from fairseq.
            Expected values are 0.1 for both Base and Large arch.
        dropout (float):
            The dropout probability applied at the end of feed forward layer.
            This option corresponds to "dropout" from fairseq.
            Expected values are 0.1 for Base and 0.0 for Large arch.
        layer_norm_first (bool):
            Control the order of layer norm in transformer layer and each encoder layer.
            If True, in transformer layer, layer norm is applied before features are fed
            to encoder layers. In encoder layer, two layer norms are applied before and after
            self attention.
            If False, in transformer layer, layer norm is applied after features are fed
            to encoder layers. In encoder layer, two layer norms are applied after self
            attention, before and after feed forward.
            This option corresponds to "layer_norm_first" from fairseq.
            Expected values are False for Base and True for Large arch.
            NOTE(review): ``EncoderLayer`` receives this flag as-is, but ``Transformer``
            receives ``not layer_norm_first``. Confirm the transformer-level description
            above against ``Transformer``.
        layer_drop (float):
            Probability to drop each encoder layer during training.
            This option corresponds to "layerdrop" from fairseq.
            Expected values are 0.1 for both Base and Large arch.
        prune_attention_heads (bool, optional):
            Passed to every ``SelfAttention`` as ``prune_heads``. (Default: ``False``)
        prune_attention_layer (bool, optional):
            Passed to every ``SelfAttention`` as ``prune_layer``. (Default: ``False``)
        prune_feed_forward_intermediate (bool, optional):
            Passed to every ``FeedForward`` as ``prune_intermediate``. (Default: ``False``)
        prune_feed_forward_layer (bool, optional):
            Passed to every ``FeedForward`` as ``prune_layer``. (Default: ``False``)

    Returns:
        Encoder: ``FeatureProjection`` followed by a ``Transformer`` with ``num_layers``
        ``EncoderLayer`` modules.

    Note:
        The per-layer lists are indexed by layer position for ``i`` in ``range(num_layers)``:
        ``use_attention[i]`` and ``use_feed_forward[i]`` are always read.
        ``num_heads[i]`` is read only when ``use_attention[i]`` is True.
        ``ff_interm_features[i]`` is read only when ``use_feed_forward[i]`` is True.

    See Also:
        * "encoder_embed_dim"
          - Def and base
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L49-L51
          - Large
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L64
        * "dropout_input"
          - Def, base and large
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L75-L78
        * "conv_pos"
          - Def, base and large
            NOTE: The upstream help text for this option is wrong; refer to its usage below.
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L204-L207
          - Usage
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L756
        * "conv_pos_groups"
          - Def, base and large
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L208-L211
        * "encoder_layers"
          - Def and base
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L46-L48
          - Large
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L63
        * "encoder_attention_heads"
          - Def and base
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L55-L57
          - Large
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L66
        * "attention_dropout"
          - Def and base
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L66-L68
          - Large
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L60
        * "encoder_ffn_embed_dim"
          - Def and base
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L52-L54
          - Large
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L65
        * "activation_dropout"
          - Def
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L69-L71
          - Base
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/base_960h.yaml#L55
          - Large
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/vox_960h.yaml#L55
        * "dropout"
          - Def and base
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L63-L65
          - Large
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L59
        * "layer_norm_first"
          - Def and base
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L91-L93
          - Large
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L53
        * "layerdrop"
          - Def
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L72-L74
          - Base
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/base_960h.yaml#L54
          - Large
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/vox_960h.yaml#L54
    """
    feature_projection = FeatureProjection(in_features, embed_dim, dropout_input)
    pos_conv = ConvolutionalPositionalEmbedding(embed_dim, pos_conv_kernel, pos_conv_groups)

    # Original impl
    # https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L768-L782
    encoder_layers = nn.ModuleList()
    for idx in range(num_layers):
        # A layer whose attention / feed-forward flag is off gets None for that submodule.
        if use_attention[idx]:
            attention = SelfAttention(
                embed_dim=embed_dim,
                num_heads=num_heads[idx],
                head_dim=head_dim,
                dropout=attention_dropout,
                prune_heads=prune_attention_heads,
                prune_layer=prune_attention_layer,
            )
        else:
            attention = None
        if use_feed_forward[idx]:
            feed_forward = FeedForward(
                io_features=embed_dim,
                intermediate_features=ff_interm_features[idx],
                intermediate_dropout=ff_interm_dropout,
                output_dropout=dropout,
                prune_intermediate=prune_feed_forward_intermediate,
                prune_layer=prune_feed_forward_layer,
            )
        else:
            feed_forward = None
        encoder_layers.append(
            EncoderLayer(
                attention=attention,
                dropout=dropout,
                layer_norm_first=layer_norm_first,
                feed_forward=feed_forward,
                embed_dim=embed_dim,
            )
        )
    # NOTE(review): the flag is deliberately negated for the Transformer-level norm;
    # confirm its meaning against Transformer, which is defined elsewhere in this file.
    transformer = Transformer(
        pos_conv_embed=pos_conv,
        dropout=dropout,
        layers=encoder_layers,
        layer_norm_first=not layer_norm_first,
        layer_drop=layer_drop,
    )
    return Encoder(feature_projection, transformer)
def _get_wavlm_encoder(
    in_features: int,
    embed_dim: int,
    dropout_input: float,
    pos_conv_kernel: int,
    pos_conv_groups: int,
    num_layers: int,
    use_attention: List[bool],
    use_feed_forward: List[bool],
    total_num_heads: List[int],
    remaining_heads: List[List[int]],
    num_buckets: int,
    max_distance: int,
    attention_dropout: float,
    ff_interm_features: List[int],
    ff_interm_dropout: float,
    dropout: float,
    layer_norm_first: bool,
    layer_drop: float,
    prune_attention_heads: bool = False,
    prune_attention_layer: bool = False,
    prune_feed_forward_intermediate: bool = False,
    prune_feed_forward_layer: bool = False,
) -> Encoder:
    """Construct the encoder for the WavLM model :cite:`chen2022wavlm`.

    The structure of the encoder and most of the arguments are the same as in
    :py:func:`_get_encoder`, so refer there for documentation. The differences from
    the Wav2Vec2 encoder are:

    * ``WavLMSelfAttention`` is used instead of ``SelfAttention``.
    * The per-layer head configuration is given by ``total_num_heads`` and
      ``remaining_heads`` instead of ``num_heads`` and ``head_dim``.
    * Relative position embedding is configured by ``num_buckets`` and
      ``max_distance``. Only the attention of layer 0 is built with
      ``has_relative_attention_bias=True``.

    Args:
        in_features (int): See :py:func:`_get_encoder`.
        embed_dim (int): See :py:func:`_get_encoder`.
        dropout_input (float): See :py:func:`_get_encoder`.
        pos_conv_kernel (int): See :py:func:`_get_encoder`.
        pos_conv_groups (int): See :py:func:`_get_encoder`.
        num_layers (int): See :py:func:`_get_encoder`.
        use_attention (list of bool): See :py:func:`_get_encoder`.
        use_feed_forward (list of bool): See :py:func:`_get_encoder`.
        total_num_heads (list of int):
            For each layer, passed to ``WavLMSelfAttention`` as ``total_num_heads``.
        remaining_heads (list of list of int):
            For each layer, passed to ``WavLMSelfAttention`` as ``remaining_heads``.
            These are presumably the indices of the heads kept after pruning;
            confirm against ``WavLMSelfAttention``.
        num_buckets (int): Number of buckets for relative position embedding.
        max_distance (int): Maximum distance for relative position embedding.
        attention_dropout (float): See :py:func:`_get_encoder`.
        ff_interm_features (list of int): See :py:func:`_get_encoder`.
        ff_interm_dropout (float): See :py:func:`_get_encoder`.
        dropout (float): See :py:func:`_get_encoder`.
        layer_norm_first (bool): See :py:func:`_get_encoder`.
        layer_drop (float): See :py:func:`_get_encoder`.
        prune_attention_heads (bool, optional):
            Passed to every ``WavLMSelfAttention`` as ``prune_heads``. (Default: ``False``)
        prune_attention_layer (bool, optional):
            Passed to every ``WavLMSelfAttention`` as ``prune_layer``. (Default: ``False``)
        prune_feed_forward_intermediate (bool, optional):
            Passed to every ``FeedForward`` as ``prune_intermediate``. (Default: ``False``)
        prune_feed_forward_layer (bool, optional):
            Passed to every ``FeedForward`` as ``prune_layer``. (Default: ``False``)

    Returns:
        Encoder: ``FeatureProjection`` followed by a ``Transformer`` with ``num_layers``
        ``EncoderLayer`` modules.

    Note:
        If ``use_attention[0]`` is False, no layer is built with a relative attention
        bias. NOTE(review): confirm that the remaining layers handle a missing
        position bias in that case.
    """
    feature_projection = FeatureProjection(in_features, embed_dim, dropout_input)
    pos_conv = ConvolutionalPositionalEmbedding(embed_dim, pos_conv_kernel, pos_conv_groups)

    # Original impl
    # https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L768-L782
    encoder_layers = nn.ModuleList()
    for i in range(num_layers):
        # A layer whose attention / feed-forward flag is off gets None for that submodule.
        if use_attention[i]:
            attention = WavLMSelfAttention(
                embed_dim=embed_dim,
                total_num_heads=total_num_heads[i],
                remaining_heads=remaining_heads[i],
                dropout=attention_dropout,
                has_relative_attention_bias=(i == 0),  # Position embedding is only necessary in the first layer.
                num_buckets=num_buckets,
                max_distance=max_distance,
                prune_heads=prune_attention_heads,
                prune_layer=prune_attention_layer,
            )
        else:
            attention = None
        if use_feed_forward[i]:
            feed_forward = FeedForward(
                io_features=embed_dim,
                intermediate_features=ff_interm_features[i],
                intermediate_dropout=ff_interm_dropout,
                output_dropout=dropout,
                prune_intermediate=prune_feed_forward_intermediate,
                prune_layer=prune_feed_forward_layer,
            )
        else:
            feed_forward = None
        encoder_layers.append(
            EncoderLayer(
                attention=attention,
                dropout=dropout,
                layer_norm_first=layer_norm_first,
                feed_forward=feed_forward,
                embed_dim=embed_dim,
            )
        )
    # NOTE(review): the flag is deliberately negated for the Transformer-level norm;
    # confirm its meaning against Transformer, which is defined elsewhere in this file.
    transformer = Transformer(
        pos_conv_embed=pos_conv,
        dropout=dropout,
        layers=encoder_layers,
        layer_norm_first=not layer_norm_first,
        layer_drop=layer_drop,
    )
    return Encoder(feature_projection, transformer)
def _get_padding_mask(input: Tensor, lengths: Tensor) -> Tensor:
    """Build a boolean mask that marks the padded frames of a batch.

    Args:
        input (Tensor): The padded batch, of dimension ``[batch, max_len, frequency]``.
            It must be 3-D, and only its shape is used.
        lengths (Tensor): The valid length of each sequence, of dimension ``[batch,]``.

    Returns:
        (Tensor): Boolean tensor of shape ``[batch, max_len]`` that is ``True`` where
        the frame index is ``>=`` the sequence length, i.e. at padded positions.
        It lives on ``lengths.device``.
    """
    batch_size, max_len, _ = input.shape
    frame_index = torch.arange(max_len, device=lengths.device)
    frame_index = frame_index.expand(batch_size, max_len)
    return frame_index >= lengths.unsqueeze(1)
class GradMultiply(torch.autograd.Function):
    """Identity in the forward pass that scales the gradient in the backward pass.

    Use as ``GradMultiply.apply(x, scale)``. The output has the same values as ``x``
    and shares its storage. The gradient flowing back to ``x`` is multiplied by
    ``scale``, and ``scale`` itself receives no gradient.
    """

    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        # Return an alias of ``x`` (no copy) through the documented ``view_as`` API
        # rather than the legacy ``Tensor.new(tensor)`` constructor, which also aliases.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad):
        # Scale the incoming gradient; ``scale`` is a non-differentiable argument.
        return grad * ctx.scale, None