#!/usr/bin/env python3
from typing import Union

import flax.linen as nn
import jax
import jax.numpy as jnp
from transformers import HubertConfig
from transformers.modeling_flax_outputs import FlaxSequenceClassifierOutput
from transformers.models.wav2vec2.modeling_flax_wav2vec2 import FlaxWav2Vec2Module, FlaxWav2Vec2PreTrainedModel


class FlaxHubertForSequenceClassificationModule(nn.Module):
    """HuBERT backbone (``FlaxWav2Vec2Module`` driven by a ``HubertConfig``) with a 2-way head.

    Frame-level hidden states from the backbone are mean-pooled over axis 1
    (padded frames excluded when an attention mask is given), passed through
    ReLU and dropout, and projected to 2 logits.
    """

    config: HubertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.hubert = FlaxWav2Vec2Module(self.config, dtype=self.dtype)
        # Regularizes the pooled representation right before the classifier head.
        self.dropout = nn.Dropout(rate=self.config.final_dropout)
        # Pooling strategy over axis 1; only "mean" pools. Any other value would
        # leave per-frame hidden states and therefore per-frame logits.
        self.reduce = "mean"
        # binary classification
        self.lm_head = nn.Dense(
            2,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
            dtype=self.dtype,
        )

    def __call__(
        self,
        input_values,
        attention_mask=None,
        mask_time_indices=None,
        deterministic=True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encode ``input_values`` with the backbone and classify the pooled features.

        Args:
            input_values: waveform batch forwarded unchanged to the backbone.
            attention_mask: optional sample-level padding mask forwarded to the
                backbone. It is also used to drop padded frames from mean pooling.
                Assumes right padding with 1 = real sample, 0 = padding
                (HF convention) — confirm against callers.
            mask_time_indices: forwarded unchanged to the backbone.
            deterministic: when False, backbone dropout and the classifier
                dropout are active (requires a "dropout" RNG).
            output_attentions / output_hidden_states: forwarded to the backbone.
            return_dict: if falsy, return a tuple ``(logits,) + outputs[2:]``.
                Otherwise return a ``FlaxSequenceClassifierOutput``.
        """
        outputs = self.hubert(
            input_values,
            attention_mask=attention_mask,
            mask_time_indices=mask_time_indices,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]

        if self.reduce == "mean":
            hidden_states = self._masked_mean(hidden_states, attention_mask)

        hidden_states = jax.nn.relu(hidden_states)
        # Previously constructed in setup() but never applied.
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        logits = self.lm_head(hidden_states)

        if not return_dict:
            return (logits,) + outputs[2:]

        return FlaxSequenceClassifierOutput(
            logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )

    def _masked_mean(self, hidden_states, attention_mask):
        """Average ``hidden_states`` over axis 1, skipping frames that come from padding.

        Without a mask this is a plain ``jnp.mean`` over axis 1. With an all-ones
        mask the result also equals the plain mean, because the conv output
        length then equals the number of frames.
        """
        if attention_mask is None:
            return jnp.mean(hidden_states, axis=1)

        num_frames = hidden_states.shape[1]
        # Real samples per example -> real frames after the conv feature extractor.
        sample_lengths = jnp.sum(attention_mask, axis=-1)
        frame_lengths = self._get_feat_extract_output_lengths(sample_lengths)
        frame_mask = jnp.arange(num_frames)[None, :] < frame_lengths[:, None]
        frame_mask = frame_mask[:, :, None].astype(hidden_states.dtype)

        summed = jnp.sum(hidden_states * frame_mask, axis=1)
        # Clamp to 1 so an example with no valid frames pools to zeros rather than NaN.
        counts = jnp.maximum(jnp.sum(frame_mask, axis=1), 1)
        return summed / counts

    def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
        """
        Computes the output length of the convolutional layers
        """

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return (input_length - kernel_size) // stride + 1

        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        return input_lengths


class FlaxHubertPreTrainedModel(FlaxWav2Vec2PreTrainedModel):
    """Wav2Vec2 pretrained-model wrapper re-targeted at ``HubertConfig``.

    Subclasses choose the underlying Flax module via ``module_class``.
    """

    config_class = HubertConfig
    base_model_prefix: str = "hubert"
    module_class: nn.Module = None

    def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
        # Delegate to the wrapped module, which owns the conv-length computation.
        return self.module._get_feat_extract_output_lengths(input_lengths)


class FlaxHubertModel(FlaxHubertPreTrainedModel):
    """Bare HuBERT encoder (``FlaxWav2Vec2Module``) without a task head."""

    module_class = FlaxWav2Vec2Module


class FlaxHubertForSequenceClassification(FlaxHubertPreTrainedModel):
    """HuBERT with the pooled 2-way classification head."""

    module_class = FlaxHubertForSequenceClassificationModule