add model logic implementation
8f96165
import torch
import transformers.models.wav2vec2.modeling_wav2vec2 as w2v2
from torch import nn
from transformers import Wav2Vec2Model

class Wav2Vec2EncoderLayer(nn.Module):
    """wav2vec 2.0 transformer encoder layer that also records its layer index `i`."""

    def __init__(self, config, i):
        super().__init__()
        self.attention = w2v2.Wav2Vec2Attention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=False,
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = w2v2.Wav2Vec2FeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.config = config
        self.i = i

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        # Self-attention block with residual connection and post-attention layer norm.
        attn_residual = hidden_states
        hidden_states, attn_weights, _ = self.attention(
            hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
        )
        hidden_states = self.dropout(hidden_states)
        hidden_states = attn_residual + hidden_states
        hidden_states = self.layer_norm(hidden_states)

        # Feed-forward block with residual connection and final layer norm.
        hidden_states = hidden_states + self.feed_forward(hidden_states)
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs

class Wav2VecWrapper(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.backbone_model = Wav2Vec2Model.from_pretrained(
            config._name_or_path,
            output_hidden_states=config.output_hidden_states,
        )
        state_dict = self.backbone_model.state_dict()
        self.model_config = self.backbone_model.config
        # Swap in the custom encoder layers, then reload the pretrained weights
        # so the freshly built layers are not left randomly initialized.
        self.backbone_model.encoder.layers = nn.ModuleList(
            [Wav2Vec2EncoderLayer(self.model_config, i) for i in range(self.model_config.num_hidden_layers)]
        )
        self.backbone_model.load_state_dict(state_dict, strict=False)

    def forward(
        self,
        input_features: torch.Tensor,
        length: torch.Tensor = None,
    ):
        # The convolutional feature extractor is kept frozen.
        with torch.no_grad():
            hidden_states = self.backbone_model.feature_extractor(input_features)
            hidden_states = hidden_states.transpose(1, 2)

        hidden_states, _ = self.backbone_model.feature_projection(hidden_states)

        # Convert raw-audio lengths to the number of frames produced by the conv stack.
        if length is not None:
            length = self.get_feat_extract_output_lengths(length.detach().cpu())

        hidden_states = self.backbone_model.encoder(
            hidden_states,
            output_hidden_states=self.config.output_hidden_states,
        ).hidden_states
        return {'encoder_hidden_states': hidden_states, 'length': length}

    def get_feat_extract_output_lengths(self, input_length):
        """Compute the output length of the convolutional feature extractor."""

        def _conv_out_length(input_length, kernel_size, stride):
            # Standard 1D convolution output-length formula.
            return (input_length - kernel_size) // stride + 1

        for kernel_size, stride in zip(self.backbone_model.config.conv_kernel, self.backbone_model.config.conv_stride):
            input_length = _conv_out_length(input_length, kernel_size, stride)
        return input_length

def prepare_mask(length, shape, dtype):
    # Build a boolean padding mask of the given shape that is True for every
    # frame before each sequence's length and False afterwards.
    mask = torch.zeros(shape, dtype=dtype)
    # Mark the last valid frame of each sequence...
    mask[(torch.arange(mask.shape[0]), length.cpu() - 1)] = 1
    # ...then propagate the mark backwards so all earlier frames become True.
    mask = mask.flip([-1]).cumsum(-1).flip([-1]).bool()
    return mask
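

# --- Usage sketch (not part of the original commit) ---------------------------
# A minimal example of how the wrapper can be driven, assuming a standard
# wav2vec 2.0 checkpoint ("facebook/wav2vec2-base-960h" is only a placeholder)
# and a config with output_hidden_states enabled. Shapes and names below are
# illustrative, not taken from the original code.
if __name__ == "__main__":
    from transformers import Wav2Vec2Config

    config = Wav2Vec2Config.from_pretrained(
        "facebook/wav2vec2-base-960h",
        output_hidden_states=True,
    )
    model = Wav2VecWrapper(config)
    model.eval()

    # Two dummy 16 kHz utterances of different lengths, zero-padded to 1 second.
    waveform = torch.randn(2, 16000)
    lengths = torch.tensor([16000, 12000])

    with torch.no_grad():
        outputs = model(waveform, length=lengths)

    hidden_states = outputs['encoder_hidden_states']  # tuple: one tensor per layer
    frame_lengths = outputs['length']                 # valid frames after the conv stack
    # Boolean mask over encoder frames, True at valid positions of each utterance.
    mask = prepare_mask(frame_lengths, hidden_states[-1].shape[:2], hidden_states[-1].dtype)
    print(len(hidden_states), hidden_states[-1].shape, mask.shape)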