flax_mms_alignment_model / flax_modeling_alignment.py
Aman K
Prepared alignment model to be loaded using FlaxAutoModel
a1c1315
from typing import Optional, Union
import jax
import jax.numpy as jnp
import flax.linen as nn
from transformers.modeling_flax_outputs import FlaxCausalLMOutput
from transformers.models.wav2vec2.configuration_wav2vec2 import Wav2Vec2Config
from transformers.models.wav2vec2.modeling_flax_wav2vec2 import (
FlaxWav2Vec2FeatureEncoder,
FlaxWav2Vec2FeatureProjection,
FlaxWav2Vec2StableLayerNormEncoder,
FlaxWav2Vec2Adapter,
FlaxWav2Vec2PreTrainedModel,
FlaxWav2Vec2BaseModelOutput,
)
class FlaxWav2Vec2Module(nn.Module):
    """Wav2Vec2 backbone: feature encoder -> feature projection -> encoder -> optional adapter.

    Only the stable-layer-norm encoder is wired up; ``setup`` raises
    ``NotImplementedError`` when ``config.do_stable_layer_norm`` is falsy.

    The ``masked_spec_embed`` parameter is created only when the config enables
    masking (``mask_time_prob > 0`` or ``mask_feature_prob > 0``). Otherwise the
    attribute is ``None`` and passing ``mask_time_indices`` to ``__call__``
    raises a ``ValueError``.
    """

    # Model hyper-parameters (conv layout, hidden size, masking probabilities, adapter settings, ...).
    config: Wav2Vec2Config
    # dtype forwarded to every sub-module built in ``setup``.
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Instantiate the sub-modules and, if masking is enabled, the mask embedding.

        Raises:
            NotImplementedError: if ``config.do_stable_layer_norm`` is falsy.
        """
        self.feature_extractor = FlaxWav2Vec2FeatureEncoder(self.config, dtype=self.dtype)
        self.feature_projection = FlaxWav2Vec2FeatureProjection(self.config, dtype=self.dtype)

        # The mask embedding is a parameter only when the config enables masking. In the
        # other case the attribute is set to None explicitly (same pattern as `adapter`
        # below) so that `__call__` can detect its absence and fail with a clear message.
        if self.config.mask_time_prob > 0.0 or self.config.mask_feature_prob > 0.0:
            self.masked_spec_embed = self.param(
                "masked_spec_embed", jax.nn.initializers.uniform(), (self.config.hidden_size,)
            )
        else:
            self.masked_spec_embed = None

        if self.config.do_stable_layer_norm:
            self.encoder = FlaxWav2Vec2StableLayerNormEncoder(self.config, dtype=self.dtype)
        else:
            raise NotImplementedError("``config.do_stable_layer_norm is False`` is currently not supported.")

        self.adapter = FlaxWav2Vec2Adapter(self.config, dtype=self.dtype) if self.config.add_adapter else None

    def __call__(
        self,
        input_values,
        attention_mask=None,
        mask_time_indices=None,
        deterministic=True,
        output_attentions=None,
        output_hidden_states=None,
        freeze_feature_encoder=False,
        return_dict=None,
    ):
        """Run the backbone on ``input_values``.

        Args:
            input_values: input passed to the feature encoder.
            attention_mask: optional mask over ``input_values``; it is reduced to
                the feature-vector resolution before being given to the encoder.
            mask_time_indices: optional 2-D mask (indexed as ``[:, :, None]``);
                positions where it is true are replaced by ``masked_spec_embed``.
            deterministic: forwarded to the feature projection and the encoder.
            output_attentions: forwarded to the encoder.
            output_hidden_states: forwarded to the encoder.
            freeze_feature_encoder: forwarded to the feature encoder.
            return_dict: if falsy a tuple is returned, otherwise a
                ``FlaxWav2Vec2BaseModelOutput``.

        Returns:
            ``(hidden_states, extract_features) + encoder_outputs[1:]`` when
            ``return_dict`` is falsy, else a ``FlaxWav2Vec2BaseModelOutput``.

        Raises:
            ValueError: if ``mask_time_indices`` is given but the model was built
                without a ``masked_spec_embed`` parameter.
        """
        extract_features = self.feature_extractor(input_values, freeze_feature_encoder=freeze_feature_encoder)

        # make sure that no loss is computed on padded inputs
        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)

        if mask_time_indices is not None:  # apply SpecAugment along time axis with given indices
            if self.masked_spec_embed is None:
                # Without this guard the access below would surface as an opaque AttributeError.
                raise ValueError(
                    "`mask_time_indices` was passed, but this model has no `masked_spec_embed` parameter because "
                    "both `config.mask_time_prob` and `config.mask_feature_prob` are 0."
                )
            hidden_states = jnp.where(
                jnp.broadcast_to(mask_time_indices[:, :, None], hidden_states.shape),
                jnp.broadcast_to(self.masked_spec_embed[None, None, :], hidden_states.shape),
                hidden_states,
            )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        if not return_dict:
            return (hidden_states, extract_features) + encoder_outputs[1:]

        return FlaxWav2Vec2BaseModelOutput(
            last_hidden_state=hidden_states,
            extract_features=extract_features,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
    ):
        """
        Computes the output length of the convolutional layers

        Args:
            input_lengths: length(s) fed to the first convolutional layer.
            add_adapter: whether to also apply the adapter layers' length
                reduction; ``None`` falls back to ``config.add_adapter``.

        Returns:
            The length(s) after all ``config.conv_kernel`` / ``config.conv_stride``
            layers and, if requested, ``config.num_adapter_layers`` adapter layers.
        """
        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return (input_length - kernel_size) // stride + 1

        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        if add_adapter:
            # Each adapter layer is treated as a kernel-size-1 conv with `adapter_stride`.
            for _ in range(self.config.num_adapter_layers):
                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)

        return input_lengths

    def _get_feature_vector_attention_mask(
        self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter: Optional[bool] = None
    ):
        """Reduce an input-level attention mask to feature-vector resolution.

        Args:
            feature_vector_length: size of the second axis of the returned mask.
            attention_mask: 2-D mask whose per-row sum is taken as the number of
                non-padded input positions.
            add_adapter: forwarded to ``_get_feat_extract_output_lengths``.

        Returns:
            A boolean array of shape ``(batch_size, feature_vector_length)`` that
            is true at every position up to and including ``output_length - 1``
            of each row.
        """
        # Effectively attention_mask.sum(-1), but not inplace to be able to run
        # on inference mode.
        non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]

        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)

        batch_size = attention_mask.shape[0]

        attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
        # these two operations makes sure that all values
        # before the output lengths indices are attended to
        attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
        attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
        return attention_mask
class FlaxWav2Vec2Model(FlaxWav2Vec2PreTrainedModel):
    """Pretrained-model wrapper whose ``module_class`` is the bare ``FlaxWav2Vec2Module`` backbone."""

    # Flax module class associated with this wrapper.
    module_class = FlaxWav2Vec2Module
class FlaxWav2Vec2ForAudioFrameClassificationModule(nn.Module):
    """Wav2Vec2 backbone followed by a dense layer applied to its first output.

    The dense layer has ``config.num_labels`` output units, so the logits carry
    one score per label for every position of the backbone's hidden states.
    """

    # Model hyper-parameters shared by the backbone and the classification head.
    config: Wav2Vec2Config
    # dtype forwarded to the backbone and the dense head.
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Build the backbone and the classification head."""
        head_init = jax.nn.initializers.normal(self.config.initializer_range)
        self.wav2vec2 = FlaxWav2Vec2Module(self.config, dtype=self.dtype)
        self.classifier = nn.Dense(self.config.num_labels, kernel_init=head_init, dtype=self.dtype)

    def __call__(
        self,
        input_values,
        attention_mask=None,
        mask_time_indices=None,
        deterministic=True,
        output_attentions=None,
        output_hidden_states=None,
        freeze_feature_encoder=False,
        return_dict=None,
    ):
        """Compute classification logits from ``input_values``.

        All keyword arguments are forwarded unchanged to the backbone.

        Returns:
            ``(logits,) + backbone_outputs[2:]`` when ``return_dict`` is falsy,
            otherwise a ``FlaxCausalLMOutput`` holding the logits together with
            the backbone's ``hidden_states`` and ``attentions``.
        """
        backbone_kwargs = {
            "attention_mask": attention_mask,
            "mask_time_indices": mask_time_indices,
            "deterministic": deterministic,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "freeze_feature_encoder": freeze_feature_encoder,
            "return_dict": return_dict,
        }
        backbone_out = self.wav2vec2(input_values, **backbone_kwargs)

        # Element 0 of the backbone output is the final hidden-state sequence.
        frame_logits = self.classifier(backbone_out[0])

        if return_dict:
            return FlaxCausalLMOutput(
                logits=frame_logits,
                hidden_states=backbone_out.hidden_states,
                attentions=backbone_out.attentions,
            )

        # Tuple path: skip the backbone's first two entries and keep the rest.
        return (frame_logits,) + backbone_out[2:]
class FlaxWav2Vec2ForAudioFrameClassification(FlaxWav2Vec2PreTrainedModel):
    """Pretrained-model wrapper whose ``module_class`` is the audio frame classification module."""

    # Flax module class associated with this wrapper.
    module_class = FlaxWav2Vec2ForAudioFrameClassificationModule