import torch
from torch import nn

import transformers.models.wav2vec2.modeling_wav2vec2 as w2v2
from transformers import Wav2Vec2Model


class Wav2Vec2EncoderLayer(nn.Module):
    """Wav2vec 2.0 encoder layer (post-layer-norm variant) that also records its index in the stack."""

    def __init__(self, config, i):
        super().__init__()
        self.attention = w2v2.Wav2Vec2Attention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=False,
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = w2v2.Wav2Vec2FeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.config = config
        self.i = i  # layer index within the encoder stack

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        # Self-attention block with residual connection and post-layer-norm.
        attn_residual = hidden_states
        hidden_states, attn_weights, _ = self.attention(
            hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
        )
        hidden_states = self.dropout(hidden_states)
        hidden_states = attn_residual + hidden_states
        hidden_states = self.layer_norm(hidden_states)

        # Feed-forward block with residual connection and final layer norm.
        hidden_states = hidden_states + self.feed_forward(hidden_states)
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs


class Wav2VecWrapper(nn.Module):
    """Wraps a pretrained Wav2Vec2Model and swaps in the custom encoder layers defined above."""

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.backbone_model = Wav2Vec2Model.from_pretrained(
            config._name_or_path,
            output_hidden_states=config.output_hidden_states,
        )
        # Snapshot the pretrained weights so they can be restored after the
        # encoder layers are swapped out below.
        state_dict = self.backbone_model.state_dict()
        self.model_config = self.backbone_model.config
        self.backbone_model.encoder.layers = nn.ModuleList(
            [Wav2Vec2EncoderLayer(self.model_config, i) for i in range(self.model_config.num_hidden_layers)]
        )
        # The custom layers mirror the original module structure, so the
        # pretrained weights load back without any key renaming.
        self.backbone_model.load_state_dict(state_dict)

    def forward(
        self,
        input_features: torch.Tensor,
        length: torch.Tensor = None,
    ):
        # Run the convolutional feature extractor and the feature projection
        # without gradient tracking, keeping them effectively frozen.
        with torch.no_grad():
            hidden_states = self.backbone_model.feature_extractor(input_features)
            hidden_states = hidden_states.transpose(1, 2)
            hidden_states, _ = self.backbone_model.feature_projection(hidden_states)

        # Convert waveform lengths (in samples) to encoder frame lengths.
        if length is not None:
            length = self.get_feat_extract_output_lengths(length.detach().cpu())

        hidden_states = self.backbone_model.encoder(
            hidden_states,
            output_hidden_states=self.config.output_hidden_states,
        ).hidden_states

        return {'encoder_hidden_states': hidden_states, 'length': length}

    def get_feat_extract_output_lengths(self, input_length):
        """Compute how many encoder frames the convolutional feature extractor produces.

        For reference, the stock wav2vec2-base conv stack maps 16000 samples (1 s at 16 kHz) to 49 frames.
        """

        def _conv_out_length(input_length, kernel_size, stride):
            # Standard 1D convolution output-length formula (no padding, no dilation).
            return (input_length - kernel_size) // stride + 1

        for kernel_size, stride in zip(self.backbone_model.config.conv_kernel, self.backbone_model.config.conv_stride):
            input_length = _conv_out_length(input_length, kernel_size, stride)
        return input_length


def prepare_mask(length, shape, dtype):
    # Build a padding mask of the given shape (batch, seq_len):
    # True for valid frames, False for padded positions past each length.
    mask = torch.zeros(shape, dtype=dtype)
    # Mark the last valid frame of each sequence, then propagate the mark
    # backwards with a reversed cumulative sum.
    mask[(torch.arange(mask.shape[0]), length.cpu() - 1)] = 1
    mask = mask.flip([-1]).cumsum(-1).flip([-1]).bool()
    return mask
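

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes the "facebook/wav2vec2-base-960h" checkpoint purely as an example;
# any wav2vec 2.0 checkpoint with the standard config attributes should work.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoConfig

    # output_hidden_states=True so the wrapper returns per-layer hidden states.
    config = AutoConfig.from_pretrained(
        "facebook/wav2vec2-base-960h", output_hidden_states=True
    )
    model = Wav2VecWrapper(config).eval()

    # Two mono clips at 16 kHz; the second is treated as having only 12000 valid samples.
    waveform = torch.randn(2, 16000)
    lengths = torch.tensor([16000, 12000])

    with torch.no_grad():
        out = model(waveform, length=lengths)

    # Frame-level padding mask built from the downsampled lengths.
    last_layer = out["encoder_hidden_states"][-1]  # (batch, frames, hidden)
    mask = prepare_mask(out["length"], last_layer.shape[:2], torch.float)
    print(len(out["encoder_hidden_states"]), last_layer.shape, mask.shape)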