import torch
import transformers.models.wav2vec2.modeling_wav2vec2 as w2v2
from torch import nn
from transformers import Wav2Vec2Model


class Wav2Vec2EncoderLayer(nn.Module):
    """Post-layer-norm wav2vec 2.0 encoder layer that also keeps its layer index `i`."""

    def __init__(self, config, i):
        super().__init__()
        self.attention = w2v2.Wav2Vec2Attention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=False,
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = w2v2.Wav2Vec2FeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.config = config
        self.i = i

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        # Self-attention block with a residual connection and post-layer-norm.
        attn_residual = hidden_states
        hidden_states, attn_weights, _ = self.attention(
            hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
        )
        hidden_states = self.dropout(hidden_states)
        hidden_states = attn_residual + hidden_states
        hidden_states = self.layer_norm(hidden_states)

        # Feed-forward block, again with a residual connection and post-layer-norm.
        hidden_states = hidden_states + self.feed_forward(hidden_states)
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs
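

# A minimal, self-contained sanity check for the layer above; it is a sketch, not part
# of the original module. The default Wav2Vec2Config, the batch size, and the sequence
# length are illustrative assumptions.
def _demo_encoder_layer():
    from transformers import Wav2Vec2Config

    cfg = Wav2Vec2Config()                       # hidden_size=768, 12 attention heads by default
    layer = Wav2Vec2EncoderLayer(cfg, i=0)
    x = torch.randn(2, 49, cfg.hidden_size)      # (batch, frames, hidden)
    (out,) = layer(x)
    assert out.shape == x.shape                  # both residual blocks preserve the shape
    return out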


class Wav2VecWrapper(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.backbone_model = Wav2Vec2Model.from_pretrained(
            config._name_or_path,
            output_hidden_states=config.output_hidden_states,
        )
        # Keep the pretrained weights around, swap in the custom encoder layers
        # (they use the same parameter names as the originals), then load the
        # weights back so the replacement layers are not left randomly initialized.
        state_dict = self.backbone_model.state_dict()
        self.model_config = self.backbone_model.config
        self.backbone_model.encoder.layers = nn.ModuleList(
            [Wav2Vec2EncoderLayer(self.model_config, i) for i in range(self.model_config.num_hidden_layers)]
        )
        self.backbone_model.load_state_dict(state_dict)

    def forward(
        self,
        input_features: torch.Tensor,
        length: torch.Tensor = None,
    ):
        # The convolutional feature extractor and the feature projection stay frozen;
        # only the (replaced) transformer encoder receives gradients.
        with torch.no_grad():
            hidden_states = self.backbone_model.feature_extractor(input_features)
            hidden_states = hidden_states.transpose(1, 2)
            hidden_states, _ = self.backbone_model.feature_projection(hidden_states)
            if length is not None:
                # Convert waveform lengths (in samples) to encoder frame lengths.
                length = self.get_feat_extract_output_lengths(length.detach().cpu())

        # Assumes config.output_hidden_states=True so that every layer's output is returned.
        hidden_states = self.backbone_model.encoder(
            hidden_states,
            output_hidden_states=self.config.output_hidden_states,
        ).hidden_states
        return {'encoder_hidden_states': hidden_states, 'length': length}

    def get_feat_extract_output_lengths(self, input_length):
        """Map raw waveform lengths to the number of frames produced by the conv feature extractor."""

        def _conv_out_length(input_length, kernel_size, stride):
            # Standard 1D convolution output length (no padding, no dilation).
            return (input_length - kernel_size) // stride + 1

        for kernel_size, stride in zip(self.backbone_model.config.conv_kernel, self.backbone_model.config.conv_stride):
            input_length = _conv_out_length(input_length, kernel_size, stride)
        return input_length
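

# A worked example of the length rule above, assuming the standard wav2vec2-base
# feature extractor (conv_kernel=(10, 3, 3, 3, 3, 2, 2), conv_stride=(5, 2, 2, 2, 2, 2, 2));
# other checkpoints may use a different convolution stack. One second of 16 kHz audio
# shrinks as 16000 -> 3199 -> 1599 -> 799 -> 399 -> 199 -> 99 -> 49 frames.
def _demo_output_lengths(num_samples: int = 16000) -> int:
    kernels = (10, 3, 3, 3, 3, 2, 2)
    strides = (5, 2, 2, 2, 2, 2, 2)
    n = num_samples
    for k, s in zip(kernels, strides):
        n = (n - k) // s + 1                     # same rule as _conv_out_length above
    return n                                     # 49 for num_samples=16000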


def prepare_mask(length, shape, dtype):
    # Build a padding mask of the given shape: True for valid frames, False for padding.
    mask = torch.zeros(shape, dtype=dtype)
    # Mark the last valid frame of each sequence, then propagate the 1 backwards
    # with flip -> cumsum -> flip so every position up to that frame becomes True.
    mask[(torch.arange(mask.shape[0]), length.cpu() - 1)] = 1
    mask = mask.flip([-1]).cumsum(-1).flip([-1]).bool()
    return mask
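

# A short sketch of how prepare_mask turns the frame lengths returned by
# Wav2VecWrapper.forward into a boolean padding mask; the lengths, batch size,
# and dtype below are illustrative assumptions.
def _demo_prepare_mask():
    length = torch.tensor([3, 5])                # valid frames per utterance
    mask = prepare_mask(length, (2, 5), torch.float32)
    # mask == tensor([[ True,  True,  True, False, False],
    #                 [ True,  True,  True,  True,  True]])
    return mask


# An end-to-end usage sketch (downloads a checkpoint; the checkpoint name and the
# input sizes are assumptions, and output_hidden_states must be enabled so the
# wrapper returns per-layer hidden states).
if __name__ == "__main__":
    from transformers import Wav2Vec2Config

    cfg = Wav2Vec2Config.from_pretrained("facebook/wav2vec2-base", output_hidden_states=True)
    model = Wav2VecWrapper(cfg)
    waveform = torch.randn(2, 16000)             # two 1-second utterances at 16 kHz
    lengths = torch.tensor([16000, 12000])       # true lengths before padding
    out = model(waveform, lengths)
    hidden = out['encoder_hidden_states']        # tuple with one tensor per layer (plus the input)
    mask = prepare_mask(out['length'], hidden[-1].shape[:2], hidden[-1].dtype)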