hubert-large-korean / modeling_flax_hubert.py
hyunwoo3235's picture
load flax hubert model via remote code
928a4d7
# coding=utf-8
# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Flax Hubert model."""
from functools import partial
from typing import Optional, Tuple, Union
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from transformers import HubertConfig
from transformers.modeling_flax_outputs import FlaxBaseModelOutput
from transformers.modeling_flax_utils import (
ACT2FN,
FlaxPreTrainedModel,
)
from transformers.utils import ModelOutput, logging
logger = logging.get_logger(__name__)
@flax.struct.dataclass
class FlaxHubertOutput(ModelOutput):
last_hidden_state: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jnp.ndarray]] = None
extract_features: jnp.ndarray = None
class FlaxConvWithWeightNorm(nn.Module):
config: HubertConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.conv = nn.Conv(
features=self.config.hidden_size,
kernel_size=(self.config.num_conv_pos_embeddings,),
kernel_init=jax.nn.initializers.he_normal(),
padding="VALID",
feature_group_count=self.config.num_conv_pos_embedding_groups,
dtype=self.dtype,
)
weight_shape = (
self.conv.features,
self.conv.features // self.conv.feature_group_count,
self.conv.kernel_size[0],
)
self.weight_v = self.param(
"weight_v", jax.nn.initializers.he_normal(), weight_shape
)
self.weight_g = self.param(
"weight_g",
lambda _: jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :],
)
self.bias = self.param("bias", jax.nn.initializers.zeros, (self.conv.features,))
self.prev_padding = self.conv.kernel_size[0] // 2
def _get_normed_weights(self):
weight_v_norm = jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :]
normed_weight_v = jnp.divide(self.weight_v, weight_v_norm)
normed_kernel = jnp.multiply(normed_weight_v, self.weight_g)
return normed_kernel
def __call__(self, hidden_states):
kernel = self._get_normed_weights()
hidden_states = jnp.pad(
hidden_states, ((0, 0), (self.prev_padding, self.prev_padding), (0, 0))
)
hidden_states = self.conv.apply(
{"params": {"kernel": kernel.T, "bias": self.bias}}, hidden_states
)
return hidden_states
class FlaxHubertNoLayerNormConvLayer(nn.Module):
config: HubertConfig
layer_id: int = 0
dtype: jnp.dtype = jnp.float32
def setup(self):
self.in_conv_dim = (
self.config.conv_dim[self.layer_id - 1] if self.layer_id > 0 else 1
)
self.out_conv_dim = self.config.conv_dim[self.layer_id]
self.conv = nn.Conv(
features=self.config.conv_dim[self.layer_id],
kernel_size=(self.config.conv_kernel[self.layer_id],),
strides=(self.config.conv_stride[self.layer_id],),
use_bias=self.config.conv_bias,
kernel_init=jax.nn.initializers.he_normal(),
padding="VALID",
dtype=self.dtype,
)
self.activation = ACT2FN[self.config.feat_extract_activation]
def __call__(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
class FlaxHubertLayerNormConvLayer(nn.Module):
config: HubertConfig
layer_id: int = 0
dtype: jnp.dtype = jnp.float32
def setup(self):
self.in_conv_dim = (
self.config.conv_dim[self.layer_id - 1] if self.layer_id > 0 else 1
)
self.out_conv_dim = self.config.conv_dim[self.layer_id]
self.conv = nn.Conv(
features=self.config.conv_dim[self.layer_id],
kernel_size=(self.config.conv_kernel[self.layer_id],),
strides=(self.config.conv_stride[self.layer_id],),
use_bias=self.config.conv_bias,
kernel_init=jax.nn.initializers.he_normal(),
padding="VALID",
dtype=self.dtype,
)
self.layer_norm = nn.LayerNorm(
epsilon=self.config.layer_norm_eps, dtype=self.dtype
)
self.activation = ACT2FN[self.config.feat_extract_activation]
def __call__(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
class FlaxHubertGroupNormConvLayer(nn.Module):
config: HubertConfig
layer_id: int = 0
dtype: jnp.dtype = jnp.float32
def setup(self):
self.in_conv_dim = (
self.config.conv_dim[self.layer_id - 1] if self.layer_id > 0 else 1
)
self.out_conv_dim = self.config.conv_dim[self.layer_id]
self.conv = nn.Conv(
features=self.config.conv_dim[self.layer_id],
kernel_size=(self.config.conv_kernel[self.layer_id],),
strides=(self.config.conv_stride[self.layer_id],),
use_bias=self.config.conv_bias,
kernel_init=jax.nn.initializers.he_normal(),
padding="VALID",
dtype=self.dtype,
)
self.activation = ACT2FN[self.config.feat_extract_activation]
self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, dtype=self.dtype)
def __call__(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
class FlaxHubertPositionalConvEmbedding(nn.Module):
config: HubertConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.conv = FlaxConvWithWeightNorm(self.config, dtype=self.dtype)
self.activation = ACT2FN[self.config.feat_extract_activation]
self.num_pad_remove = 1 if self.config.num_conv_pos_embeddings % 2 == 0 else 0
def __call__(self, hidden_states):
hidden_states = hidden_states.transpose((0, 1, 2))
hidden_states = self.conv(hidden_states)
if self.num_pad_remove > 0:
hidden_states = hidden_states[:, : -self.num_pad_remove, :]
hidden_states = self.activation(hidden_states)
hidden_states = hidden_states.transpose((0, 1, 2))
return hidden_states
class FlaxConvLayersCollection(nn.Module):
config: HubertConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
if self.config.feat_extract_norm == "layer":
self.layers = [
FlaxHubertLayerNormConvLayer(
self.config, layer_id=i, name=str(i), dtype=self.dtype
)
for i in range(self.config.num_feat_extract_layers)
]
elif self.config.feat_extract_norm == "group":
self.layers = [
FlaxHubertGroupNormConvLayer(
self.config, layer_id=0, name=str(0), dtype=self.dtype
)
] + [
FlaxHubertNoLayerNormConvLayer(
self.config, layer_id=i, name=str(i), dtype=self.dtype
)
for i in range(1, self.config.num_feat_extract_layers)
]
else:
raise ValueError(
f"`config.feat_extract_norm` is {self.config.feat_extract_norm}, but has to be one of ['group',"
" 'layer']"
)
def __call__(self, hidden_states):
for i, conv_layer in enumerate(self.layers):
hidden_states = conv_layer(hidden_states)
return hidden_states
class FlaxHubertFeatureEncoder(nn.Module):
config: HubertConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.conv_layers = FlaxConvLayersCollection(self.config, dtype=self.dtype)
def __call__(self, input_values, freeze_feature_encoder=False):
hidden_states = input_values[:, :, None]
hidden_states = self.conv_layers(hidden_states)
if freeze_feature_encoder:
hidden_states = jax.lax.stop_gradient(hidden_states)
return hidden_states
class FlaxHubertFeatureProjection(nn.Module):
config: HubertConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.feat_proj_layer_norm = self.config.feat_proj_layer_norm
if self.feat_proj_layer_norm:
self.layer_norm = nn.LayerNorm(
epsilon=self.config.layer_norm_eps, dtype=self.dtype
)
self.projection = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.dropout = nn.Dropout(rate=self.config.feat_proj_dropout)
def __call__(self, hidden_states, deterministic=True):
if self.feat_proj_layer_norm:
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.projection(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
return hidden_states
class FlaxHubertAttention(nn.Module):
config: HubertConfig
embed_dim: int
num_heads: int
dropout: float = 0.0
bias: bool = True
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self) -> None:
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {self.num_heads})."
)
self.scaling = self.head_dim**-0.5
dense = partial(
nn.Dense,
self.embed_dim,
use_bias=self.bias,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
self.out_proj = dense()
self.dropout_layer = nn.Dropout(rate=self.dropout)
def _split_heads(self, hidden_states):
return hidden_states.reshape(
hidden_states.shape[:2] + (self.num_heads, self.head_dim)
)
def _merge_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
def __call__(
self,
hidden_states: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
output_attentions: bool = False,
deterministic: bool = True,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
"""Input shape: Batch x Time x Channel"""
# get query, key, value proj for self_attention
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = self._split_heads(query_states)
key_states = self._split_heads(key_states)
value_states = self._split_heads(value_states)
if attention_mask is not None:
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
attention_bias = lax.select(
attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(
self.dtype
),
)
else:
attention_bias = None
dropout_rng = None
if not deterministic and self.dropout > 0.0:
dropout_rng = self.make_rng("dropout")
attn_weights = dot_product_attention_weights(
query_states,
key_states,
bias=attention_bias,
dropout_rng=dropout_rng,
dropout_rate=self.dropout,
broadcast_dropout=True,
deterministic=deterministic,
dtype=self.dtype,
precision=None,
)
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
attn_output = self._merge_heads(attn_output)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights
class FlaxHubertFeedForward(nn.Module):
config: HubertConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.intermediate_dropout = nn.Dropout(self.config.activation_dropout)
self.intermediate_dense = nn.Dense(
self.config.intermediate_size, dtype=self.dtype
)
if isinstance(self.config.hidden_act, str):
self.intermediate_activation = ACT2FN[self.config.hidden_act]
else:
self.intermediate_activation = self.config.hidden_act
self.output_dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
self.output_dropout = nn.Dropout(self.config.activation_dropout)
def __call__(self, hidden_states, deterministic=True):
hidden_states = self.intermediate_dense(hidden_states)
hidden_states = self.intermediate_activation(hidden_states)
hidden_states = self.intermediate_dropout(
hidden_states, deterministic=deterministic
)
hidden_states = self.output_dense(hidden_states)
hidden_states = self.output_dropout(hidden_states, deterministic=deterministic)
return hidden_states
class FlaxHubertEncoderLayer(nn.Module):
config: HubertConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.attention = FlaxHubertAttention(
config=self.config,
embed_dim=self.config.hidden_size,
num_heads=self.config.num_attention_heads,
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
self.dropout = nn.Dropout(self.config.hidden_dropout)
self.layer_norm = nn.LayerNorm(
epsilon=self.config.layer_norm_eps, dtype=self.dtype
)
self.feed_forward = FlaxHubertFeedForward(self.config, dtype=self.dtype)
self.final_layer_norm = nn.LayerNorm(
epsilon=self.config.layer_norm_eps, dtype=self.dtype
)
def __call__(
self,
hidden_states,
attention_mask: Optional[jnp.ndarray] = None,
output_attentions: bool = False,
deterministic=True,
):
attn_residual = hidden_states
hidden_states, attn_weights = self.attention(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
deterministic=deterministic,
)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
hidden_states = attn_residual + hidden_states
hidden_states = self.layer_norm(hidden_states)
hidden_states = hidden_states + self.feed_forward(
hidden_states, deterministic=deterministic
)
hidden_states = self.final_layer_norm(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class FlaxHubertEncoderLayerStableLayerNorm(nn.Module):
config: HubertConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.attention = FlaxHubertAttention(
config=self.config,
embed_dim=self.config.hidden_size,
num_heads=self.config.num_attention_heads,
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
self.dropout = nn.Dropout(self.config.hidden_dropout)
self.layer_norm = nn.LayerNorm(
epsilon=self.config.layer_norm_eps, dtype=self.dtype
)
self.feed_forward = FlaxHubertFeedForward(self.config, dtype=self.dtype)
self.final_layer_norm = nn.LayerNorm(
epsilon=self.config.layer_norm_eps, dtype=self.dtype
)
def __call__(
self,
hidden_states,
attention_mask: Optional[jnp.ndarray] = None,
output_attentions: bool = False,
deterministic=True,
):
attn_residual = hidden_states
hidden_states = self.layer_norm(hidden_states)
hidden_states, attn_weights = self.attention(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
deterministic=deterministic,
)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
hidden_states = attn_residual + hidden_states
hidden_states = hidden_states + self.feed_forward(
self.final_layer_norm(hidden_states), deterministic=deterministic
)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class FlaxHubertLayerCollection(nn.Module):
config: HubertConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.layers = [
FlaxHubertEncoderLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.num_hidden_layers)
]
def __call__(
self,
hidden_states,
attention_mask=None,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = layer(
hidden_states,
attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions += (layer_outputs[1],)
if output_hidden_states:
all_hidden_states += (hidden_states,)
outputs = (hidden_states, all_hidden_states, all_attentions)
if not return_dict:
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_attentions,
)
class FlaxHubertEncoder(nn.Module):
config: HubertConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.pos_conv_embed = FlaxHubertPositionalConvEmbedding(
self.config, dtype=self.dtype
)
self.layer_norm = nn.LayerNorm(
epsilon=self.config.layer_norm_eps, dtype=self.dtype
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout)
self.layers = FlaxHubertLayerCollection(self.config, dtype=self.dtype)
def __call__(
self,
hidden_states,
attention_mask: Optional[jnp.ndarray] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
):
if attention_mask is not None:
# make sure padded tokens are not attended to
hidden_states = jnp.where(
jnp.broadcast_to(attention_mask[:, :, None], hidden_states.shape),
hidden_states,
0,
)
position_embeddings = self.pos_conv_embed(hidden_states)
hidden_states = hidden_states + position_embeddings
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
outputs = self.layers(
hidden_states,
attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = self.layer_norm(outputs[0])
hidden_states = None
if output_hidden_states:
hidden_states = outputs[1]
hidden_states = hidden_states[:-1] + (last_hidden_state,)
if not return_dict:
outputs = (last_hidden_state, hidden_states) + (
outputs[2:] if output_hidden_states else outputs[1:]
)
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutput(
last_hidden_state=last_hidden_state,
hidden_states=hidden_states,
attentions=outputs.attentions,
)
class FlaxHubertLayerStableLayerNormCollection(nn.Module):
config: HubertConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.layers = [
FlaxHubertEncoderLayerStableLayerNorm(
self.config, name=str(i), dtype=self.dtype
)
for i in range(self.config.num_hidden_layers)
]
def __call__(
self,
hidden_states,
attention_mask=None,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = layer(
hidden_states,
attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions += (layer_outputs[1],)
if output_hidden_states:
all_hidden_states += (hidden_states,)
outputs = (hidden_states, all_hidden_states, all_attentions)
if not return_dict:
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_attentions,
)
class FlaxHubertEncoderStableLayerNorm(nn.Module):
config: HubertConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.pos_conv_embed = FlaxHubertPositionalConvEmbedding(
self.config, dtype=self.dtype
)
self.layer_norm = nn.LayerNorm(
epsilon=self.config.layer_norm_eps, dtype=self.dtype
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout)
self.layers = FlaxHubertLayerStableLayerNormCollection(
self.config, dtype=self.dtype
)
def __call__(
self,
hidden_states,
attention_mask: Optional[jnp.ndarray] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
):
if attention_mask is not None:
hidden_states = jnp.where(
jnp.broadcast_to(attention_mask[:, :, None], hidden_states.shape),
hidden_states,
0,
)
position_embeddings = self.pos_conv_embed(hidden_states)
hidden_states = hidden_states + position_embeddings
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
outputs = self.layers(
hidden_states,
attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = self.layer_norm(outputs[0])
hidden_states = None
if output_hidden_states:
hidden_states = outputs[1]
hidden_states = hidden_states[:-1] + (last_hidden_state,)
if not return_dict:
outputs = (last_hidden_state, hidden_states) + (
outputs[2:] if output_hidden_states else outputs[1:]
)
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutput(
last_hidden_state=last_hidden_state,
hidden_states=hidden_states,
attentions=outputs.attentions,
)
class FlaxHubertPreTrainedModel(FlaxPreTrainedModel):
config_class = HubertConfig
base_model_prefix = "hubert"
main_input_name = "input_values"
module_class: nn.Module = None
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(
self,
config: HubertConfig,
input_shape: Tuple = (1, 1024),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs,
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(
config,
module,
input_shape=input_shape,
seed=seed,
dtype=dtype,
_do_init=_do_init,
)
def init_weights(
self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None
) -> FrozenDict:
input_values = jnp.zeros(input_shape, dtype="i4")
attention_mask = jnp.ones_like(input_values)
params_rng, dropout_rng = jax.random.split(rng, 2)
rngs = {"params": params_rng, "dropout": dropout_rng}
random_params = self.module.init(
rngs, input_values, attention_mask, return_dict=False
)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
def __call__(
self,
input_values,
attention_mask=None,
mask_time_indices=None,
params: dict = None,
dropout_rng: jax.random.PRNGKey = None,
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
freeze_feature_encoder: bool = False,
return_dict: Optional[bool] = None,
):
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.return_dict
)
batch_size, sequence_length = input_values.shape
if attention_mask is None:
attention_mask = jnp.ones((batch_size, sequence_length))
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
inputs = {"params": params or self.params}
return self.module.apply(
inputs,
jnp.array(input_values, dtype="f4"),
jnp.array(attention_mask, dtype="i4"),
mask_time_indices,
not train,
output_attentions,
output_hidden_states,
freeze_feature_encoder,
return_dict,
rngs=rngs,
)
class FlaxHubertModule(nn.Module):
config: HubertConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.feature_extractor = FlaxHubertFeatureEncoder(self.config, dtype=self.dtype)
self.feature_projection = FlaxHubertFeatureProjection(
self.config, dtype=self.dtype
)
if self.config.mask_time_prob > 0.0 or self.config.mask_feature_prob > 0.0:
self.masked_spec_embed = self.param(
"masked_spec_embed",
nn.initializers.uniform(dtype=self.dtype),
(self.config.hidden_size,),
)
if self.config.do_stable_layer_norm:
self.encoder = FlaxHubertEncoderStableLayerNorm(self.config)
else:
self.encoder = FlaxHubertEncoder(self.config)
def __call__(
self,
input_values: Optional[jnp.ndarray],
attention_mask: Optional[jnp.ndarray] = None,
mask_time_indices: Optional[jnp.ndarray] = None,
deterministic: bool = True,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
freeze_feature_encoder: bool = False,
return_dict: Optional[bool] = None,
) -> Union[Tuple, FlaxHubertOutput]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
extract_features = self.feature_extractor(input_values, freeze_feature_encoder)
if attention_mask is not None:
attention_mask = self._get_feature_vector_attention_mask(
extract_features.shape[1], attention_mask
)
hidden_states = self.feature_projection(
extract_features, deterministic=deterministic
)
if mask_time_indices is not None:
hidden_states = jnp.where(
jnp.broadcast_to(mask_time_indices[:, :, None], hidden_states.shape),
jnp.broadcast_to(
self.masked_spec_embed[None, None, :], hidden_states.shape
),
hidden_states,
)
encoder_outputs = self.encoder(
hidden_states,
attention_mask=attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = encoder_outputs[0]
if not return_dict:
return (hidden_states,) + encoder_outputs[1:]
return FlaxHubertOutput(
last_hidden_state=hidden_states,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
extract_features=extract_features,
)
def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
def _conv_out_length(input_length, kernel_size, stride):
return (input_length - kernel_size) // stride + 1
for kernel_size, stride in zip(
self.config.conv_kernel, self.config.conv_stride
):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
return input_lengths
def _get_feature_vector_attention_mask(
self, feature_vector_length: int, attention_mask: jnp.ndarray
):
non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]
output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths)
batch_size = attention_mask.shape[0]
attention_mask = jnp.zeros(
(batch_size, feature_vector_length), dtype=attention_mask.dtype
)
attention_mask = attention_mask.at[
jnp.arange(attention_mask.shape[0]), output_lengths - 1
].set(1)
attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype(
"bool"
)
return attention_mask
class FlaxHubertModel(FlaxHubertPreTrainedModel):
module_class = FlaxHubertModule