from typing import Optional, Union

import jax
import jax.numpy as jnp
import flax.linen as nn
from transformers.modeling_flax_outputs import FlaxCausalLMOutput
from transformers.models.wav2vec2.configuration_wav2vec2 import Wav2Vec2Config
from transformers.models.wav2vec2.modeling_flax_wav2vec2 import (
    FlaxWav2Vec2FeatureEncoder,
    FlaxWav2Vec2FeatureProjection,
    FlaxWav2Vec2StableLayerNormEncoder,
    FlaxWav2Vec2Adapter,
    FlaxWav2Vec2PreTrainedModel,
    FlaxWav2Vec2BaseModelOutput,
)


class FlaxWav2Vec2Module(nn.Module):
    """Wav2Vec2 backbone: conv feature encoder -> feature projection -> encoder -> optional adapter.

    Only the stable-layer-norm encoder variant is supported: ``setup`` raises
    ``NotImplementedError`` when ``config.do_stable_layer_norm`` is false.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.feature_extractor = FlaxWav2Vec2FeatureEncoder(self.config, dtype=self.dtype)
        self.feature_projection = FlaxWav2Vec2FeatureProjection(self.config, dtype=self.dtype)

        # The learned mask vector only exists when masking is enabled in the config;
        # ``__call__`` reads it whenever ``mask_time_indices`` is supplied.
        if self.config.mask_time_prob > 0.0 or self.config.mask_feature_prob > 0.0:
            self.masked_spec_embed = self.param(
                "masked_spec_embed", jax.nn.initializers.uniform(), (self.config.hidden_size,)
            )

        if self.config.do_stable_layer_norm:
            self.encoder = FlaxWav2Vec2StableLayerNormEncoder(self.config, dtype=self.dtype)
        else:
            raise NotImplementedError("``config.do_stable_layer_norm is False`` is currently not supported.")

        self.adapter = FlaxWav2Vec2Adapter(self.config, dtype=self.dtype) if self.config.add_adapter else None

    def __call__(
        self,
        input_values,
        attention_mask=None,
        mask_time_indices=None,
        deterministic=True,
        output_attentions=None,
        output_hidden_states=None,
        freeze_feature_encoder=False,
        return_dict=None,
    ):
        """Run the backbone on raw input values.

        Args:
            input_values: raw input fed to the conv feature encoder.
            attention_mask: optional 2-D mask over ``input_values`` (1 = real input,
                0 = padding). It is reduced to frame resolution before it reaches the encoder.
            mask_time_indices: optional boolean-like ``(batch, frames)`` array. Selected
                frames are replaced by the learned ``masked_spec_embed`` vector. That
                parameter only exists when ``mask_time_prob`` or ``mask_feature_prob`` > 0.
            deterministic: forwarded to the feature projection and the encoder.
            output_attentions: forwarded to the encoder.
            output_hidden_states: forwarded to the encoder.
            freeze_feature_encoder: forwarded to the conv feature encoder.
            return_dict: when falsy (including the default ``None``) a tuple is returned.

        Returns:
            ``(hidden_states, extract_features, *encoder_outputs[1:])`` when
            ``return_dict`` is falsy, otherwise a ``FlaxWav2Vec2BaseModelOutput``.
        """
        extract_features = self.feature_extractor(input_values, freeze_feature_encoder=freeze_feature_encoder)

        # make sure that no loss is computed on padded inputs
        if attention_mask is not None:
            # Compute the reduced, frame-level attention mask. ``add_adapter=False``
            # because the mask is consumed by the encoder, which runs before the adapter.
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)
        if mask_time_indices is not None:
            # apply SpecAugment along time axis with given indices
            hidden_states = jnp.where(
                jnp.broadcast_to(mask_time_indices[:, :, None], hidden_states.shape),
                jnp.broadcast_to(self.masked_spec_embed[None, None, :], hidden_states.shape),
                hidden_states,
            )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        if not return_dict:
            return (hidden_states, extract_features) + encoder_outputs[1:]

        return FlaxWav2Vec2BaseModelOutput(
            last_hidden_state=hidden_states,
            extract_features=extract_features,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
    ):
        """
        Computes the output length of the convolutional layers.

        Applies the 1-D convolution length formula once per
        ``(conv_kernel, conv_stride)`` pair in the config. When ``add_adapter`` is
        true (defaulting to ``config.add_adapter``), it is applied
        ``num_adapter_layers`` more times with kernel 1 and ``adapter_stride``.
        Lengths may come out zero or negative for inputs shorter than the
        receptive field.
        """
        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return (input_length - kernel_size) // stride + 1

        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        if add_adapter:
            for _ in range(self.config.num_adapter_layers):
                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)

        return input_lengths

    def _get_feature_vector_attention_mask(
        self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
    ):
        """Reduce an input-level attention mask to a frame-level boolean mask.

        Args:
            feature_vector_length: number of frames (columns) in the returned mask.
            attention_mask: 2-D ``(batch, input_length)`` mask. Its row sum is taken as
                the unpadded input length, so entries are expected to be 1 for real
                input and 0 for padding.
            add_adapter: forwarded to ``_get_feat_extract_output_lengths``.

        Returns:
            Boolean ``(batch, feature_vector_length)`` array that is ``True`` for the
            first ``output_length`` frames of each row and ``False`` after that. Rows
            whose computed output length is <= 0 (fully padded, or shorter than the
            conv receptive field) are entirely ``False``.
        """
        non_padded_lengths = attention_mask.sum(axis=-1)
        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)

        # Compare frame positions against each row's valid length. The previous
        # approach scattered a 1 at ``output_lengths - 1`` and back-filled with a
        # reversed cumsum. That index becomes negative for non-positive lengths and
        # wraps to the end of the row, so empty or too-short inputs were marked
        # (mostly) attended.
        frame_positions = jnp.arange(feature_vector_length)
        return frame_positions[None, :] < output_lengths[:, None]


class FlaxWav2Vec2Model(FlaxWav2Vec2PreTrainedModel):
    """Binds ``FlaxWav2Vec2Module`` as the ``module_class`` of the pretrained-model wrapper."""

    module_class = FlaxWav2Vec2Module


class FlaxWav2Vec2ForAudioFrameClassificationModule(nn.Module):
    """Wav2Vec2 backbone followed by a dense classifier over the last hidden state.

    The classifier produces ``config.num_labels`` logits per output frame.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # NOTE(review): the backbone runs its adapter when ``config.add_adapter`` is set,
        # so logits would then be at the adapter's frame rate, not the conv feature
        # rate. Confirm this is intended for frame-level labels.
        self.wav2vec2 = FlaxWav2Vec2Module(self.config, dtype=self.dtype)
        self.classifier = nn.Dense(
            self.config.num_labels,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )

    def __call__(
        self,
        input_values,
        attention_mask=None,
        mask_time_indices=None,
        deterministic=True,
        output_attentions=None,
        output_hidden_states=None,
        freeze_feature_encoder=False,
        return_dict=None,
    ):
        """Classify every output frame of the backbone.

        All arguments are forwarded unchanged to ``FlaxWav2Vec2Module.__call__``.

        Returns:
            ``(logits, *backbone_outputs[2:])`` when ``return_dict`` is falsy, otherwise a
            ``FlaxCausalLMOutput`` carrying the logits plus the backbone's hidden
            states and attentions.
        """
        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            mask_time_indices=mask_time_indices,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            freeze_feature_encoder=freeze_feature_encoder,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.classifier(hidden_states)

        if not return_dict:
            # Skip ``extract_features`` (index 1 of the backbone tuple).
            return (logits,) + outputs[2:]

        # NOTE(review): a CausalLM output type is used for frame classification. Its
        # fields (logits, hidden_states, attentions) fit, but the type name is misleading.
        return FlaxCausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)


class FlaxWav2Vec2ForAudioFrameClassification(FlaxWav2Vec2PreTrainedModel):
    """Binds the frame-classification module as the ``module_class`` of the pretrained-model wrapper."""

    module_class = FlaxWav2Vec2ForAudioFrameClassificationModule