""" Flax whisper model.""" |
|
|
|
import random |
|
from functools import partial |
|
from typing import Dict, Optional, Tuple, Union |
|
|
|
import flax.linen as nn |
|
import jax |
|
import jax.numpy as jnp |
|
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze |
|
from flax.linen import combine_masks, make_causal_mask |
|
from flax.linen.attention import dot_product_attention_weights |
|
from flax.linen.partitioning import remat, scan_with_axes |
|
from flax.traverse_util import flatten_dict, unflatten_dict |
|
from jax import lax |
|
from jax.random import PRNGKey |
|
from transformers import WhisperConfig |
|
from transformers.generation.flax_logits_process import ( |
|
FlaxLogitsProcessor, |
|
FlaxLogitsProcessorList, |
|
FlaxWhisperTimeStampLogitsProcessor, |
|
) |
|
from transformers.modeling_flax_outputs import ( |
|
FlaxBaseModelOutput, |
|
FlaxBaseModelOutputWithPastAndCrossAttentions, |
|
FlaxCausalLMOutputWithCrossAttentions, |
|
FlaxSeq2SeqLMOutput, |
|
FlaxSeq2SeqModelOutput, |
|
) |
|
from transformers.modeling_flax_utils import ( |
|
ACT2FN, |
|
FlaxPreTrainedModel, |
|
append_call_sample_docstring, |
|
append_replace_return_docstrings, |
|
overwrite_call_docstring, |
|
) |
|
from transformers.utils import ( |
|
add_start_docstrings, |
|
add_start_docstrings_to_model_forward, |
|
logging, |
|
replace_return_docstrings, |
|
) |
|
|
|
from .layers import Conv, DenseGeneral, Embed, LayerNorm, with_sharding_constraint |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
_CHECKPOINT_FOR_DOC = "openai/whisper-tiny" |
|
_CONFIG_FOR_DOC = "WhisperConfig" |
|
|
|
|
|
WHISPER_START_DOCSTRING = r""" |
|
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the |
|
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads |
|
etc.) This model is also a Flax Linen |
|
[flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a |
|
    regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.
|
Finally, this model supports inherent JAX features such as: |
|
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) |
|
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) |
|
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) |
|
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) |
|
|
|
Parameters: |
|
config ([`WhisperConfig`]): Model configuration class with all the parameters of the model. |
|
Initializing with a config file does not load the weights associated with the model, only the |
|
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. |
|
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): |
|
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and |
|
`jax.numpy.bfloat16` (on TPUs). This can be used to enable mixed-precision training or half-precision |
|
            inference on GPUs or TPUs. If specified, all computation will be performed with the given `dtype`.
|
**Note that this only specifies the dtype of the computation and does not influence the dtype of model |
|
parameters.** If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] |
|
and [`~FlaxPreTrainedModel.to_bf16`]. |
|
""" |
|
|
|
WHISPER_INPUTS_DOCSTRING = r""" |
|
Args: |
|
input_features (`numpy.ndarray` of shape `(batch_size, feature_size, sequence_length)`): |
|
            Float values of mel features extracted from the raw speech waveform. The raw waveform can be obtained by
|
loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via |
|
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the |
|
[`WhisperFeatureExtractor`] should be used for extracting the features, padding and conversion into a |
|
tensor of type `numpy.ndarray`. See [`~WhisperFeatureExtractor.__call__`] |
|
attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): |
|
            Whisper does not support masking of the `input_features`; this argument is preserved for compatibility but

            is not used. By default, silence in the input log mel spectrogram is ignored.
|
decoder_input_ids (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): |
|
Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using |
|
[`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. |
|
[What are decoder input IDs?](../glossary#decoder-input-ids) Whisper uses the `decoder_start_token_id` as |
|
the starting token for `decoder_input_ids` generation. |
|
decoder_attention_mask (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): |
|
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also |
|
be used by default. If you want to change padding behavior, you should modify to your needs. See diagram 1 |
|
in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. |
|
position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): |
|
Whisper does not use `position_ids` in the encoder as `input_features` is always the same size and doesn't |
|
            use masking, but this argument is preserved for compatibility. By default, silence in the input log mel

            spectrogram is ignored.
|
decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): |
|
            Indices of positions of each decoder input sequence token in the position embeddings. Selected in the
|
range `[0, config.max_position_embeddings - 1]`. |
|
output_attentions (`bool`, *optional*): |
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned |
|
tensors for more detail. |
|
output_hidden_states (`bool`, *optional*): |
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for |
|
more detail. |
|
return_dict (`bool`, *optional*): |
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
|
""" |
|
|
|
WHISPER_ENCODE_INPUTS_DOCSTRING = r""" |
|
Args: |
|
input_features (`numpy.ndarray` of shape `(batch_size, feature_size, sequence_length)`): |
|
            Float values of mel features extracted from the raw speech waveform. The raw waveform can be obtained by
|
loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via |
|
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the |
|
[`WhisperFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a |
|
tensor of type `numpy.ndarray`. See [`~WhisperFeatureExtractor.__call__`]. |
|
attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): |
|
            Whisper does not support masking of the `input_features`; this argument is preserved for compatibility but

            is not used. By default, silence in the input log mel spectrogram is ignored.
|
output_attentions (`bool`, *optional*): |
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned |
|
tensors for more detail. |
|
output_hidden_states (`bool`, *optional*): |
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for |
|
more detail. |
|
return_dict (`bool`, *optional*): |
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
|
""" |
|
|
|
WHISPER_DECODE_INPUTS_DOCSTRING = r""" |
|
Args: |
|
decoder_input_ids (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`): |
|
Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using |
|
[`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. |
|
[What are decoder input IDs?](../glossary#decoder-input-ids) |
|
encoder_outputs (`tuple(tuple(numpy.ndarray)`): |
|
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) |
|
            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of
|
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. |
|
encoder_attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): |
|
            Whisper does not support masking of the `input_features`; this argument is preserved for compatibility,

            but it is not used. By default, silence in the input log mel spectrogram is ignored.
|
decoder_attention_mask (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): |
|
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also |
|
be used by default. If you want to change padding behavior, you should modify to your needs. See diagram 1 |
|
in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. |
|
decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): |
|
            Indices of positions of each decoder input sequence token in the position embeddings. Selected in the
|
range `[0, config.max_position_embeddings - 1]`. |
|
past_key_values (`Dict[str, numpy.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): |
|
Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast |
|
auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. |
|
output_attentions (`bool`, *optional*): |
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned |
|
tensors for more detail. |
|
output_hidden_states (`bool`, *optional*): |
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for |
|
more detail. |
|
return_dict (`bool`, *optional*): |
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
|
""" |
|
|
|
|
|
class FlaxStaticForceTokensLogitsProcessor(FlaxLogitsProcessor): |
|
r""" |
|
    [`FlaxLogitsProcessor`] that takes a list of integer pairs which indicate a mapping from generation indices to

    token indices that will be forced before sampling. The processor sets the forced tokens' log probs to 0 and all other

    tokens to `-inf`, so that the forced tokens are sampled at their corresponding indices. This is a static version of the `transformers` logits
|
processor [`FlaxForceTokensLogitsProcessor`] that is compatible with sharded forced tokens. |
|
|
|
Args: |
|
force_token_map (`list`): |
|
Map giving token ids and indices where they will be forced to be sampled. |
|
""" |
|
|
|
def __init__(self, force_token_map): |
|
|
|
|
|
force_token_map = jnp.array(force_token_map) |
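        # Convert the [[index, token]] pairs into a dense lookup table: force_token_array[i] holds the token id
        # forced at generation index i, or -1 if no token is forced at that position. The fixed length of 3 matches
        # the number of forced-token slots this processor expects.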
|
|
|
|
|
|
|
|
|
|
|
|
|
force_token_array = jnp.ones(3, dtype=jnp.int32) * -1 |
|
for index, token in force_token_map: |
|
force_token_array = force_token_array.at[index].set(token) |
|
self.force_token_array = jnp.int32(force_token_array) |
|
|
|
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray: |
|
def _force_token(generation_idx): |
|
batch_size = scores.shape[0] |
|
current_token = self.force_token_array[generation_idx] |
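            # Set every score to -inf, then write 0 at the forced token's column so it is the only candidate
            # left with non-zero probability.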
|
|
|
new_scores = jnp.ones_like(scores, dtype=scores.dtype) * -float("inf") |
|
updates = jnp.zeros((batch_size, 1), dtype=scores.dtype) |
|
new_scores = lax.dynamic_update_slice(new_scores, updates, (0, current_token)) |
|
return new_scores |
|
|
|
scores = lax.cond( |
|
cur_len >= self.force_token_array.shape[0], |
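            # Past the last forced position: leave the scores untouched.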
|
|
|
lambda: scores, |
|
|
|
lambda: lax.cond( |
|
self.force_token_array[cur_len] >= 0, |
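                # Only valid (non-negative) token ids are forced; -1 marks an unforced position.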
|
|
|
lambda: _force_token(cur_len), |
|
|
|
lambda: scores, |
|
), |
|
) |
|
return scores |
|
|
|
|
|
class FlaxWhisperAttention(nn.Module): |
|
config: WhisperConfig |
|
embed_dim: int |
|
num_heads: int |
|
dropout: float = 0.0 |
|
causal: bool = False |
|
bias: bool = True |
|
dtype: jnp.dtype = jnp.float32 |
|
params_dtype: jnp.dtype = jnp.float32 |
|
|
|
def setup(self) -> None: |
|
self.head_dim = self.embed_dim // self.num_heads |
|
if self.head_dim * self.num_heads != self.embed_dim: |
|
raise ValueError( |
|
"embed_dim must be divisible by num_heads (got `embed_dim`:" |
|
f" {self.embed_dim} and `num_heads`: {self.num_heads})." |
|
) |
|
|
|
dense = partial( |
|
DenseGeneral, |
|
self.embed_dim, |
|
axis=-1, |
|
dtype=self.dtype, |
|
params_dtype=self.params_dtype, |
|
kernel_axes=("embed", "joined_kv"), |
|
) |
|
|
|
self.q_proj = dense(use_bias=self.bias) |
|
self.k_proj = dense(use_bias=False) |
|
self.v_proj = dense(use_bias=self.bias) |
|
|
|
self.out_proj = DenseGeneral( |
|
self.embed_dim, |
|
axis=-1, |
|
dtype=self.dtype, |
|
params_dtype=self.params_dtype, |
|
kernel_axes=("joined_kv", "embed"), |
|
use_bias=self.bias, |
|
) |
|
|
|
if self.causal: |
|
self.causal_mask = make_causal_mask( |
|
jnp.ones((1, self.config.max_target_positions), dtype="bool"), |
|
dtype="bool", |
|
) |
|
|
|
def __call__( |
|
self, |
|
hidden_states: jnp.ndarray, |
|
key_value_states: Optional[jnp.ndarray] = None, |
|
attention_mask: Optional[jnp.ndarray] = None, |
|
init_cache: bool = False, |
|
deterministic: bool = True, |
|
) -> Tuple[jnp.ndarray]: |
|
is_cross_attention = key_value_states is not None |
|
batch_size = hidden_states.shape[0] |
|
|
|
query_states = self.q_proj(hidden_states) |
|
|
|
if is_cross_attention: |
|
key_states = self.k_proj(key_value_states) |
|
value_states = self.v_proj(key_value_states) |
|
else: |
|
key_states = self.k_proj(hidden_states) |
|
value_states = self.v_proj(hidden_states) |
|
|
|
query_states = self._split_heads(query_states) |
|
key_states = self._split_heads(key_states) |
|
value_states = self._split_heads(value_states) |
|
|
|
query_states = with_sharding_constraint(query_states, ("batch", "length", "heads", "kv")) |
|
key_states = with_sharding_constraint(key_states, ("batch", "length", "heads", "kv")) |
|
value_states = with_sharding_constraint(value_states, ("batch", "length", "heads", "kv")) |
|
|
|
if self.causal: |
|
query_length, key_length = query_states.shape[1], key_states.shape[1] |
|
if self.has_variable("cache", "cached_key"): |
|
mask_shift = self.variables["cache"]["cache_index"] |
|
|
|
max_decoder_length = self.variables["cache"]["cached_key"].shape[-1] |
|
causal_mask = lax.dynamic_slice( |
|
self.causal_mask, |
|
(0, 0, mask_shift, 0), |
|
(1, 1, query_length, max_decoder_length), |
|
) |
|
else: |
|
causal_mask = self.causal_mask[:, :, :query_length, :key_length] |
|
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) |
|
|
|
|
|
if attention_mask is not None and self.causal: |
|
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) |
|
attention_mask = combine_masks(attention_mask, causal_mask) |
|
elif self.causal: |
|
attention_mask = causal_mask |
|
elif attention_mask is not None: |
|
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) |
|
|
|
|
|
|
|
|
|
if self.causal and (self.has_variable("cache", "cached_key") or init_cache): |
|
key_states, value_states, attention_mask = self._concatenate_to_cache( |
|
key_states, value_states, query_states, attention_mask |
|
) |
|
|
|
|
|
if attention_mask is not None: |
|
|
|
attention_bias = lax.select( |
|
attention_mask > 0, |
|
jnp.full(attention_mask.shape, 0.0).astype(self.dtype), |
|
jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), |
|
) |
|
else: |
|
attention_bias = None |
|
|
|
dropout_rng = None |
|
if not deterministic and self.dropout > 0.0: |
|
dropout_rng = self.make_rng("dropout") |
|
|
|
attn_weights = dot_product_attention_weights( |
|
query_states, |
|
key_states, |
|
bias=attention_bias, |
|
dropout_rng=dropout_rng, |
|
dropout_rate=self.dropout, |
|
broadcast_dropout=True, |
|
deterministic=deterministic, |
|
dtype=self.dtype, |
|
precision=None, |
|
) |
|
|
|
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) |
|
attn_output = self._merge_heads(attn_output) |
|
attn_output = self.out_proj(attn_output) |
|
|
|
return attn_output, attn_weights |
|
|
|
def _split_heads(self, hidden_state) -> jnp.ndarray: |
|
return hidden_state.reshape(hidden_state.shape[:2] + (self.num_heads, self.head_dim)) |
|
|
|
def _merge_heads(self, hidden_state) -> jnp.ndarray: |
|
return hidden_state.reshape(hidden_state.shape[:2] + (self.embed_dim,)) |
|
|
|
@nn.compact |
|
def _concatenate_to_cache(self, key, value, query, attention_mask): |
|
|
|
is_initialized = self.has_variable("cache", "cached_key") |
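        # Keys and values arrive as [batch, seq, heads, head_dim] but are cached with the sequence axis last,
        # i.e. [batch, heads, head_dim, seq], so that a single decoding step can be written into the cache with a
        # one-hot multiply (see below) rather than a dynamic scatter.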
|
|
|
|
|
|
|
|
|
|
|
|
|
def swap_dims(x): |
|
return x[:-3] + tuple(x[i] for i in [-2, -1, -3]) |
|
|
|
cached_key = self.variable("cache", "cached_key", jnp.zeros, swap_dims(key.shape), key.dtype) |
|
cached_value = self.variable("cache", "cached_value", jnp.zeros, swap_dims(value.shape), value.dtype) |
|
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) |
|
|
|
if is_initialized: |
|
batch_size, num_heads, head_dim, seq_length = cached_key.value.shape |
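            # During incremental decoding the query typically covers a single position; sanity-check its shape
            # against the cache layout before updating.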
|
|
|
|
|
|
|
num_updated_cache_vectors = query.shape[1] |
|
expected_shape = (batch_size, 1, num_heads, head_dim) |
|
if num_updated_cache_vectors == 1 and expected_shape != query.shape: |
|
raise ValueError( |
|
"Autoregressive cache shape error, expected query shape" |
|
f" {expected_shape} instead got {query.shape}" |
|
) |
|
|
|
|
|
cur_index = cache_index.value |
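            # Move the sequence axis of the incoming key/value slice to the last dimension so it lines up with
            # the cache layout [batch, heads, head_dim, seq].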
|
|
|
|
|
|
|
|
|
|
|
|
|
one_token_key = jnp.moveaxis(key, -3, -1) |
|
one_token_value = jnp.moveaxis(value, -3, -1) |
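            # Multi-position updates (e.g. ingesting a prompt) are written into the first
            # `num_updated_cache_vectors` slots via an identity-like matrix; single-token decoding uses a
            # one-hot vector at the current cache index.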
|
|
|
|
|
|
|
|
|
if num_updated_cache_vectors > 1: |
|
indices = jnp.eye(num_updated_cache_vectors, seq_length)[None, None] |
|
key = cached_key.value + jnp.matmul(one_token_key, indices) |
|
value = cached_value.value + jnp.matmul(one_token_value, indices) |
|
else: |
|
one_hot_indices = jax.nn.one_hot(cur_index, seq_length, dtype=key.dtype) |
|
key = cached_key.value + one_token_key * one_hot_indices |
|
value = cached_value.value + one_token_value * one_hot_indices |
|
|
|
cached_key.value = key |
|
cached_value.value = value |
|
cache_index.value = cache_index.value + num_updated_cache_vectors |
|
|
|
|
|
key = jnp.moveaxis(key, -1, -3) |
|
value = jnp.moveaxis(value, -1, -3) |
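            # Mask out cache positions that have not been filled yet: a query may only attend to positions
            # up to and including the ones written in this step.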
|
|
|
|
|
|
|
|
|
pad_mask = jnp.broadcast_to( |
|
jnp.arange(seq_length) < cur_index + num_updated_cache_vectors, |
|
(batch_size,) + (1, num_updated_cache_vectors, seq_length), |
|
) |
|
attention_mask = combine_masks(pad_mask, attention_mask) |
|
|
|
return key, value, attention_mask |
|
|
|
|
|
class FlaxWhisperEncoderLayer(nn.Module): |
|
config: WhisperConfig |
|
dtype: jnp.dtype = jnp.float32 |
|
params_dtype: jnp.dtype = jnp.float32 |
|
use_scan: bool = False |
|
|
|
def setup(self) -> None: |
|
self.embed_dim = self.config.d_model |
|
self.self_attn = FlaxWhisperAttention( |
|
config=self.config, |
|
embed_dim=self.embed_dim, |
|
num_heads=self.config.encoder_attention_heads, |
|
dropout=self.config.attention_dropout, |
|
dtype=self.dtype, |
|
params_dtype=self.params_dtype, |
|
) |
|
self.self_attn_layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype) |
|
self.dropout_layer = nn.Dropout(rate=self.config.dropout) |
|
self.activation_fn = ACT2FN[self.config.activation_function] |
|
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) |
|
self.fc1 = DenseGeneral( |
|
self.config.encoder_ffn_dim, |
|
dtype=self.dtype, |
|
params_dtype=self.params_dtype, |
|
kernel_axes=("embed", "mlp"), |
|
) |
|
self.fc2 = DenseGeneral( |
|
self.embed_dim, |
|
dtype=self.dtype, |
|
params_dtype=self.params_dtype, |
|
kernel_axes=("mlp", "embed"), |
|
) |
|
self.final_layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype) |
|
|
|
def __call__( |
|
self, |
|
hidden_states: jnp.ndarray, |
|
attention_mask: jnp.ndarray, |
|
output_attentions: bool = True, |
|
deterministic: bool = True, |
|
all_hidden_states=None, |
|
) -> Tuple[jnp.ndarray]: |
|
if self.use_scan: |
|
hidden_states = hidden_states[0] |
|
|
|
hidden_states = with_sharding_constraint(hidden_states, ("batch", "length", "embed")) |
|
|
|
residual = hidden_states |
|
|
|
layernorm_output = self.self_attn_layer_norm(hidden_states) |
|
layernorm_output = with_sharding_constraint(layernorm_output, ("batch", "length", "embed")) |
|
|
|
attn_output, attn_weights = self.self_attn(hidden_states=layernorm_output, attention_mask=attention_mask) |
|
attn_output = self.dropout_layer(attn_output, deterministic=deterministic) |
|
attn_output = residual + attn_output |
|
attn_output = with_sharding_constraint(attn_output, ("batch", "length", "embed")) |
|
|
|
residual = attn_output |
|
|
|
post_layer_norm = self.final_layer_norm(attn_output) |
|
post_layer_norm = with_sharding_constraint(post_layer_norm, ("batch", "length", "embed")) |
|
|
|
fc1_output = self.activation_fn(self.fc1(post_layer_norm)) |
|
fc1_output = self.activation_dropout_layer(fc1_output, deterministic=deterministic) |
|
fc1_output = with_sharding_constraint(fc1_output, ("batch", "length", "mlp")) |
|
|
|
hidden_states = self.fc2(fc1_output) |
|
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) |
|
hidden_states = residual + hidden_states |
|
hidden_states = with_sharding_constraint(hidden_states, ("batch", "length", "embed")) |
|
|
|
outputs = (hidden_states,) |
|
|
|
if output_attentions: |
|
outputs += (attn_weights,) |
|
|
|
if self.use_scan: |
|
if all_hidden_states is not None: |
|
all_hidden_states = all_hidden_states + (hidden_states,) |
|
outputs = ( |
|
outputs, |
|
all_hidden_states, |
|
) |
|
|
|
return outputs |
|
|
|
|
|
class FlaxWhisperEncoderLayerCollection(nn.Module): |
|
config: WhisperConfig |
|
dtype: jnp.dtype = jnp.float32 |
|
params_dtype: jnp.dtype = jnp.float32 |
|
use_scan: bool = False |
|
gradient_checkpointing: bool = False |
|
|
|
@nn.compact |
|
def __call__( |
|
self, |
|
hidden_states, |
|
attention_mask, |
|
deterministic: bool = True, |
|
output_attentions: bool = False, |
|
output_hidden_states: bool = False, |
|
return_dict: bool = True, |
|
): |
|
all_attentions = () if output_attentions else None |
|
all_hidden_states = () if output_hidden_states else None |
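        # Optionally wrap the encoder layer in `remat` so activations are recomputed in the backward pass
        # (gradient checkpointing). `prevent_cse` is disabled when the layers are scanned, as recommended for
        # remat inside scan.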
|
|
|
FlaxWhisperEncoderCheckpointLayer = ( |
|
remat( |
|
FlaxWhisperEncoderLayer, |
|
static_argnums=(2, 3), |
|
prevent_cse=not self.use_scan, |
|
) |
|
if self.gradient_checkpointing |
|
else FlaxWhisperEncoderLayer |
|
) |
|
|
|
if self.use_scan: |
|
if output_attentions: |
|
raise ValueError("Cannot use `scan` with `output_attentions` set to True") |
|
|
|
|
|
|
|
input_hidden_states = hidden_states |
|
hidden_states = (hidden_states,) |
|
|
|
hidden_states, all_hidden_states = scan_with_axes( |
|
FlaxWhisperEncoderCheckpointLayer, |
|
variable_axes={"params": 0, "cache": 0}, |
|
split_rngs={"params": True, "dropout": True}, |
|
in_axes=( |
|
nn.broadcast, |
|
nn.broadcast, |
|
nn.broadcast, |
|
nn.broadcast, |
|
), |
|
variable_carry="all_hidden_states", |
|
length=self.config.encoder_layers, |
|
)( |
|
self.config, |
|
dtype=self.dtype, |
|
params_dtype=self.params_dtype, |
|
use_scan=True, |
|
name="FlaxEncoderScanLayers", |
|
)( |
|
hidden_states, |
|
attention_mask, |
|
output_attentions, |
|
deterministic, |
|
all_hidden_states, |
|
) |
|
|
|
|
|
hidden_states = hidden_states[0] |
|
|
|
if output_hidden_states: |
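                # `all_hidden_states` from scan is a stacked array of per-layer outputs; prepend the layer
                # inputs so it matches the convention of the unrolled branch.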
|
|
|
all_hidden_states = jnp.vstack([input_hidden_states[None, ...], all_hidden_states[0]]) |
|
|
|
else: |
|
for layer_idx in range(self.config.encoder_layers): |
|
if output_hidden_states: |
|
all_hidden_states = all_hidden_states + (hidden_states,) |
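                # LayerDrop: with probability `encoder_layerdrop`, skip this layer entirely during training.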
|
|
|
dropout_probability = random.uniform(0, 1) |
|
if not deterministic and (dropout_probability < self.config.encoder_layerdrop): |
|
layer_outputs = (None, None) |
|
else: |
|
layer_outputs = FlaxWhisperEncoderCheckpointLayer( |
|
self.config, |
|
dtype=self.dtype, |
|
params_dtype=self.params_dtype, |
|
name=str(layer_idx), |
|
)( |
|
hidden_states, |
|
attention_mask, |
|
output_attentions, |
|
deterministic, |
|
) |
|
hidden_states = layer_outputs[0] |
|
if output_attentions: |
|
all_attentions = all_attentions + (layer_outputs[1],) |
|
|
|
if output_hidden_states: |
|
all_hidden_states += (hidden_states,) |
|
|
|
outputs = (hidden_states, all_hidden_states, all_attentions) |
|
|
|
if not return_dict: |
|
return tuple(v for v in outputs if v is not None) |
|
|
|
return FlaxBaseModelOutput( |
|
last_hidden_state=hidden_states, |
|
hidden_states=all_hidden_states, |
|
attentions=all_attentions, |
|
) |
|
|
|
|
|
class FlaxWhisperDecoderLayer(nn.Module): |
|
config: WhisperConfig |
|
dtype: jnp.dtype = jnp.float32 |
|
params_dtype: jnp.dtype = jnp.float32 |
|
use_scan: bool = False |
|
|
|
def setup(self) -> None: |
|
self.embed_dim = self.config.d_model |
|
self.self_attn = FlaxWhisperAttention( |
|
config=self.config, |
|
embed_dim=self.embed_dim, |
|
num_heads=self.config.decoder_attention_heads, |
|
dropout=self.config.attention_dropout, |
|
causal=True, |
|
dtype=self.dtype, |
|
params_dtype=self.params_dtype, |
|
) |
|
self.dropout_layer = nn.Dropout(rate=self.config.dropout) |
|
self.activation_fn = ACT2FN[self.config.activation_function] |
|
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) |
|
|
|
self.self_attn_layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype) |
|
self.encoder_attn = FlaxWhisperAttention( |
|
config=self.config, |
|
embed_dim=self.embed_dim, |
|
num_heads=self.config.decoder_attention_heads, |
|
dropout=self.config.attention_dropout, |
|
dtype=self.dtype, |
|
params_dtype=self.params_dtype, |
|
) |
|
self.encoder_attn_layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype) |
|
self.fc1 = DenseGeneral( |
|
self.config.decoder_ffn_dim, |
|
dtype=self.dtype, |
|
params_dtype=self.params_dtype, |
|
kernel_axes=("embed", "mlp"), |
|
) |
|
self.fc2 = DenseGeneral( |
|
self.embed_dim, |
|
dtype=self.dtype, |
|
params_dtype=self.params_dtype, |
|
kernel_axes=("mlp", "embed"), |
|
) |
|
self.final_layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype) |
|
|
|
def __call__( |
|
self, |
|
hidden_states: jnp.ndarray, |
|
attention_mask: jnp.ndarray, |
|
encoder_hidden_states: Optional[jnp.ndarray] = None, |
|
encoder_attention_mask: Optional[jnp.ndarray] = None, |
|
init_cache: bool = False, |
|
output_attentions: bool = True, |
|
deterministic: bool = True, |
|
all_hidden_states=None, |
|
) -> Tuple[jnp.ndarray]: |
|
if self.use_scan: |
|
hidden_states = hidden_states[0] |
|
|
|
hidden_states = with_sharding_constraint(hidden_states, ("batch", "length", "embed")) |
|
|
|
residual = hidden_states |
|
|
|
layer_norm_output = self.self_attn_layer_norm(hidden_states) |
|
layer_norm_output = with_sharding_constraint(layer_norm_output, ("batch", "length", "embed")) |
|
|
|
|
|
self_attn_output, self_attn_weights = self.self_attn( |
|
hidden_states=layer_norm_output, |
|
attention_mask=attention_mask, |
|
init_cache=init_cache, |
|
) |
|
self_attn_output = self.dropout_layer(self_attn_output, deterministic=deterministic) |
|
self_attn_output = residual + self_attn_output |
|
self_attn_output = with_sharding_constraint(self_attn_output, ("batch", "length", "embed")) |
|
|
|
|
|
cross_attn_weights = None |
|
if encoder_hidden_states is not None: |
|
residual = self_attn_output |
|
|
|
encoder_layer_norm_output = self.encoder_attn_layer_norm(self_attn_output) |
|
encoder_layer_norm_output = with_sharding_constraint( |
|
encoder_layer_norm_output, ("batch", "length", "embed") |
|
) |
|
|
|
cross_attn_output, cross_attn_weights = self.encoder_attn( |
|
hidden_states=encoder_layer_norm_output, |
|
key_value_states=encoder_hidden_states, |
|
attention_mask=encoder_attention_mask, |
|
) |
|
cross_attn_output = self.dropout_layer(cross_attn_output, deterministic=deterministic) |
|
cross_attn_output = residual + cross_attn_output |
|
cross_attn_output = with_sharding_constraint(cross_attn_output, ("batch", "length", "embed")) |
|
|
|
|
|
residual = cross_attn_output |
|
|
|
post_layer_norm = self.final_layer_norm(cross_attn_output) |
|
post_layer_norm = with_sharding_constraint(post_layer_norm, ("batch", "length", "embed")) |
|
|
|
fc1_output = self.activation_fn(self.fc1(post_layer_norm)) |
|
fc1_output = self.activation_dropout_layer(fc1_output, deterministic=deterministic) |
|
fc1_output = with_sharding_constraint(fc1_output, ("batch", "length", "mlp")) |
|
|
|
hidden_states = self.fc2(fc1_output) |
|
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) |
|
hidden_states = residual + hidden_states |
|
hidden_states = with_sharding_constraint(hidden_states, ("batch", "length", "embed")) |
|
|
|
outputs = (hidden_states,) |
|
|
|
if output_attentions: |
|
outputs += (self_attn_weights, cross_attn_weights) |
|
|
|
if self.use_scan: |
|
if all_hidden_states is not None: |
|
all_hidden_states = all_hidden_states + (hidden_states,) |
|
outputs = ( |
|
outputs, |
|
all_hidden_states, |
|
) |
|
|
|
return outputs |
|
|
|
|
|
class FlaxWhisperDecoderLayerCollection(nn.Module): |
|
config: WhisperConfig |
|
dtype: jnp.dtype = jnp.float32 |
|
params_dtype: jnp.dtype = jnp.float32 |
|
use_scan: bool = False |
|
gradient_checkpointing: bool = False |
|
|
|
@nn.compact |
|
def __call__( |
|
self, |
|
hidden_states, |
|
attention_mask, |
|
encoder_hidden_states: Optional[jnp.ndarray] = None, |
|
encoder_attention_mask: Optional[jnp.ndarray] = None, |
|
deterministic: bool = True, |
|
init_cache: bool = False, |
|
output_attentions: bool = False, |
|
output_hidden_states: bool = False, |
|
return_dict: bool = True, |
|
): |
|
|
|
all_hidden_states = () if output_hidden_states else None |
|
all_self_attns = () if output_attentions else None |
|
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None |
|
|
|
FlaxWhisperDecoderCheckpointLayer = ( |
|
remat( |
|
FlaxWhisperDecoderLayer, |
|
static_argnums=(4, 5, 6), |
|
prevent_cse=not self.use_scan, |
|
) |
|
if self.gradient_checkpointing |
|
else FlaxWhisperDecoderLayer |
|
) |
|
|
|
if self.use_scan: |
|
if output_attentions: |
|
raise ValueError("Cannot use `scan` with `output_attentions` set to True") |
|
|
|
input_hidden_states = hidden_states |
|
hidden_states = (hidden_states,) |
|
|
|
hidden_states, all_hidden_states = scan_with_axes( |
|
FlaxWhisperDecoderCheckpointLayer, |
|
variable_axes={"params": 0, "cache": 0}, |
|
split_rngs={"params": True, "dropout": True}, |
|
in_axes=( |
|
nn.broadcast, |
|
nn.broadcast, |
|
nn.broadcast, |
|
nn.broadcast, |
|
nn.broadcast, |
|
nn.broadcast, |
|
nn.broadcast, |
|
), |
|
variable_carry="all_hidden_states", |
|
length=self.config.decoder_layers, |
|
)( |
|
self.config, |
|
dtype=self.dtype, |
|
params_dtype=self.params_dtype, |
|
use_scan=True, |
|
name="FlaxDecoderScanLayers", |
|
)( |
|
hidden_states, |
|
attention_mask, |
|
encoder_hidden_states, |
|
encoder_attention_mask, |
|
init_cache, |
|
output_attentions, |
|
deterministic, |
|
all_hidden_states, |
|
) |
|
hidden_states = hidden_states[0] |
|
|
|
if output_hidden_states: |
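                # As in the encoder collection, prepend the layer inputs to the stacked per-layer outputs
                # returned by scan.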
|
|
|
all_hidden_states = jnp.vstack([input_hidden_states[None, ...], all_hidden_states[0]]) |
|
|
|
else: |
|
for layer_idx in range(self.config.decoder_layers): |
|
if output_hidden_states: |
|
all_hidden_states += (hidden_states,) |
|
|
|
dropout_probability = random.uniform(0, 1) |
|
if not deterministic and (dropout_probability < self.config.decoder_layerdrop): |
|
layer_outputs = (None, None, None) |
|
else: |
|
layer_outputs = FlaxWhisperDecoderCheckpointLayer( |
|
self.config, |
|
dtype=self.dtype, |
|
params_dtype=self.params_dtype, |
|
name=str(layer_idx), |
|
)( |
|
hidden_states, |
|
attention_mask, |
|
encoder_hidden_states, |
|
encoder_attention_mask, |
|
init_cache, |
|
output_attentions, |
|
deterministic, |
|
) |
|
|
|
hidden_states = layer_outputs[0] |
|
if output_attentions: |
|
all_self_attns += (layer_outputs[1],) |
|
|
|
if encoder_hidden_states is not None: |
|
all_cross_attentions += (layer_outputs[2],) |
|
|
|
|
|
if output_hidden_states: |
|
all_hidden_states += (hidden_states,) |
|
|
|
outputs = [ |
|
hidden_states, |
|
all_hidden_states, |
|
all_self_attns, |
|
all_cross_attentions, |
|
] |
|
|
|
if not return_dict: |
|
return tuple(v for v in outputs if v is not None) |
|
|
|
return FlaxBaseModelOutputWithPastAndCrossAttentions( |
|
last_hidden_state=hidden_states, |
|
hidden_states=all_hidden_states, |
|
attentions=all_self_attns, |
|
cross_attentions=all_cross_attentions, |
|
) |
|
|
|
|
|
class FlaxWhisperEncoder(nn.Module): |
|
config: WhisperConfig |
|
dtype: jnp.dtype = jnp.float32 |
|
params_dtype: jnp.dtype = jnp.float32 |
|
use_scan: bool = False |
|
gradient_checkpointing: bool = False |
|
|
|
def setup(self) -> None: |
|
self.conv1 = Conv( |
|
self.config.d_model, |
|
kernel_size=(3,), |
|
padding=1, |
|
dtype=self.dtype, |
|
params_dtype=self.params_dtype, |
|
kernel_axes=("channels", "num_mel", "embed"), |
|
) |
|
self.conv2 = Conv( |
|
self.config.d_model, |
|
kernel_size=(3,), |
|
strides=2, |
|
padding=1, |
|
dtype=self.dtype, |
|
params_dtype=self.params_dtype, |
|
kernel_axes=("channels", "embed", "num_mel"), |
|
) |
|
|
|
self.dropout_layer = nn.Dropout(rate=self.config.dropout) |
|
|
|
self.layers = FlaxWhisperEncoderLayerCollection( |
|
self.config, |
|
dtype=self.dtype, |
|
params_dtype=self.params_dtype, |
|
use_scan=self.use_scan, |
|
gradient_checkpointing=self.gradient_checkpointing, |
|
) |
|
self.embed_positions = Embed( |
|
self.config.max_source_positions, |
|
self.config.d_model, |
|
dtype=self.dtype, |
|
params_dtype=self.params_dtype, |
|
) |
|
|
|
self.layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype) |
|
|
|
def __call__( |
|
self, |
|
input_features: jnp.ndarray, |
|
output_attentions: bool = False, |
|
output_hidden_states: bool = False, |
|
return_dict: bool = True, |
|
deterministic: bool = True, |
|
) -> Tuple[jnp.ndarray]: |
|
if input_features.shape[1:] != ( |
|
self.config.num_mel_bins, |
|
self.config.max_source_positions * 2, |
|
): |
|
raise ValueError( |
|
"input_features.shape[1:], must be equal to (self.config.num_mel_bins," |
|
" self.config.max_source_positions * 2) (got" |
|
f" {input_features.shape[1:]}, but should be" |
|
f" ({self.config.num_mel_bins}," |
|
f" {self.config.max_source_positions * 2}))" |
|
) |
|
|
|
input_features = input_features.transpose(0, 2, 1) |
|
hidden_states = jax.nn.gelu(self.conv1(input_features), approximate=False) |
|
hidden_states = with_sharding_constraint(hidden_states, ("batch", "embed", "num_mel")) |
|
hidden_states = jax.nn.gelu(self.conv2(hidden_states), approximate=False) |
|
hidden_states = with_sharding_constraint(hidden_states, ("batch", "length", "embed")) |
|
|
|
embed_positions = self.embed_positions(jnp.arange(self.config.max_source_positions)) |
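        # The encoder's positional embeddings are meant to stay fixed (sinusoidal in the original Whisper), so
        # gradients are stopped to keep them frozen during training.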
|
|
|
embed_positions = jax.lax.stop_gradient(embed_positions) |
|
hidden_states = hidden_states + embed_positions |
|
|
|
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) |
|
|
|
outputs = self.layers( |
|
hidden_states, |
|
attention_mask=None, |
|
deterministic=deterministic, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
last_hidden_states = outputs[0] |
|
last_hidden_states = self.layer_norm(last_hidden_states) |
|
|
|
|
|
hidden_states = None |
|
if output_hidden_states: |
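            # The final layer norm is applied on top of the last layer's output, so swap the last entry of the
            # hidden-state stack for the normalised states returned below.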
|
hidden_states = outputs[1] |
|
if self.use_scan: |
|
hidden_states = jnp.vstack([hidden_states[:-1], last_hidden_states[None, ...]]) |
|
else: |
|
hidden_states = hidden_states[:-1] + (last_hidden_states,) |
|
|
|
if not return_dict: |
|
outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:]) |
|
return tuple(v for v in outputs if v is not None) |
|
|
|
return FlaxBaseModelOutput( |
|
last_hidden_state=last_hidden_states, |
|
hidden_states=hidden_states, |
|
attentions=outputs.attentions, |
|
) |
|
|
|
|
|
class FlaxWhisperDecoder(nn.Module): |
|
config: WhisperConfig |
|
dtype: jnp.dtype = jnp.float32 |
|
params_dtype: jnp.dtype = jnp.float32 |
|
use_scan: bool = False |
|
gradient_checkpointing: bool = False |
|
|
|
def setup(self) -> None: |
|
self.embed_tokens = Embed( |
|
self.config.vocab_size, |
|
self.config.d_model, |
|
dtype=self.dtype, |
|
params_dtype=self.params_dtype, |
|
) |
|
self.embed_positions = Embed( |
|
self.config.max_target_positions, |
|
self.config.d_model, |
|
dtype=self.dtype, |
|
params_dtype=self.params_dtype, |
|
) |
|
|
|
self.layers = FlaxWhisperDecoderLayerCollection( |
|
self.config, |
|
dtype=self.dtype, |
|
params_dtype=self.params_dtype, |
|
use_scan=self.use_scan, |
|
gradient_checkpointing=self.gradient_checkpointing, |
|
) |
|
|
|
self.dropout_layer = nn.Dropout(rate=self.config.dropout) |
|
|
|
self.layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-5, params_dtype=self.params_dtype) |
|
|
|
def __call__( |
|
self, |
|
input_ids: jnp.ndarray, |
|
attention_mask: jnp.ndarray, |
|
position_ids: jnp.ndarray, |
|
encoder_hidden_states: Optional[jnp.ndarray] = None, |
|
init_cache: bool = False, |
|
output_attentions: bool = False, |
|
output_hidden_states: bool = False, |
|
return_dict: bool = True, |
|
deterministic: bool = True, |
|
) -> Tuple[jnp.ndarray]: |
|
input_embeds = self.embed_tokens(input_ids) |
|
position_embeds = self.embed_positions(position_ids) |
|
|
|
hidden_states = input_embeds + position_embeds |
|
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) |
|
|
|
outputs = self.layers( |
|
hidden_states, |
|
attention_mask=attention_mask, |
|
encoder_hidden_states=encoder_hidden_states, |
|
deterministic=deterministic, |
|
init_cache=init_cache, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
last_hidden_states = outputs[0] |
|
last_hidden_states = self.layer_norm(last_hidden_states) |
|
|
|
|
|
hidden_states = None |
|
if output_hidden_states: |
|
hidden_states = outputs[1] |
|
if self.use_scan: |
|
hidden_states = jnp.vstack([hidden_states[:-1], last_hidden_states[None, ...]]) |
|
else: |
|
hidden_states = hidden_states[:-1] + (last_hidden_states,) |
|
|
|
if not return_dict: |
|
outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:]) |
|
return tuple(v for v in outputs if v is not None) |
|
|
|
return FlaxBaseModelOutputWithPastAndCrossAttentions( |
|
last_hidden_state=last_hidden_states, |
|
hidden_states=hidden_states, |
|
attentions=outputs.attentions, |
|
cross_attentions=outputs.cross_attentions, |
|
) |
|
|
|
|
|
class FlaxWhisperModule(nn.Module): |
|
config: WhisperConfig |
|
dtype: jnp.dtype = jnp.float32 |
|
params_dtype: jnp.dtype = jnp.float32 |
|
use_scan: bool = False |
|
gradient_checkpointing: bool = False |
|
|
|
def setup(self) -> None: |
|
self.encoder = FlaxWhisperEncoder( |
|
self.config, |
|
dtype=self.dtype, |
|
params_dtype=self.params_dtype, |
|
use_scan=self.use_scan, |
|
gradient_checkpointing=self.gradient_checkpointing, |
|
) |
|
self.decoder = FlaxWhisperDecoder( |
|
self.config, |
|
dtype=self.dtype, |
|
params_dtype=self.params_dtype, |
|
use_scan=self.use_scan, |
|
gradient_checkpointing=self.gradient_checkpointing, |
|
) |
|
|
|
def __call__( |
|
self, |
|
input_features: jnp.ndarray, |
|
decoder_input_ids: jnp.ndarray, |
|
decoder_attention_mask: jnp.ndarray, |
|
decoder_position_ids: jnp.ndarray, |
|
output_attentions: bool = False, |
|
output_hidden_states: bool = False, |
|
freeze_encoder: bool = False, |
|
return_dict: bool = True, |
|
deterministic: bool = True, |
|
): |
|
encoder_outputs = self.encoder( |
|
input_features, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
deterministic=deterministic, |
|
) |
|
|
|
encoder_hidden_states = encoder_outputs[0] |
|
|
|
if freeze_encoder: |
|
encoder_hidden_states = jax.lax.stop_gradient(encoder_hidden_states) |
|
|
|
decoder_outputs = self.decoder( |
|
input_ids=decoder_input_ids, |
|
attention_mask=decoder_attention_mask, |
|
position_ids=decoder_position_ids, |
|
encoder_hidden_states=encoder_hidden_states, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
deterministic=deterministic, |
|
) |
|
|
|
if not return_dict: |
|
return decoder_outputs + encoder_outputs |
|
|
|
return FlaxSeq2SeqModelOutput( |
|
last_hidden_state=decoder_outputs.last_hidden_state, |
|
decoder_hidden_states=decoder_outputs.hidden_states, |
|
decoder_attentions=decoder_outputs.attentions, |
|
cross_attentions=decoder_outputs.cross_attentions, |
|
encoder_last_hidden_state=encoder_outputs.last_hidden_state, |
|
encoder_hidden_states=encoder_outputs.hidden_states, |
|
encoder_attentions=encoder_outputs.attentions, |
|
) |
|
|
|
def _get_encoder_module(self): |
|
return self.encoder |
|
|
|
def _get_decoder_module(self): |
|
return self.decoder |
|
|
|
|
|
class FlaxWhisperPreTrainedModel(FlaxPreTrainedModel): |
|
config_class = WhisperConfig |
|
base_model_prefix: str = "model" |
|
main_input_name = "input_features" |
|
module_class: nn.Module = None |
|
|
|
def __init__( |
|
self, |
|
config: WhisperConfig, |
|
input_shape: Tuple[int, int, int] = None, |
|
seed: int = 0, |
|
dtype: jnp.dtype = jnp.float32, |
|
params_dtype: jnp.dtype = jnp.float32, |
|
_do_init: bool = True, |
|
|
|
use_scan: bool = False, |
|
gradient_checkpointing: bool = False, |
|
**kwargs, |
|
): |
|
self.use_scan = use_scan |
|
self.gradient_checkpointing = gradient_checkpointing |
|
|
|
module = self.module_class( |
|
config=config, |
|
dtype=dtype, |
|
params_dtype=params_dtype, |
|
use_scan=use_scan, |
|
gradient_checkpointing=gradient_checkpointing, |
|
**kwargs, |
|
) |
|
|
|
if input_shape is None: |
|
            # 2 * max_source_positions input frames corresponds to the default 30-second log-mel segment
            input_shape = (1, config.num_mel_bins, 2 * config.max_source_positions)
|
|
|
super().__init__( |
|
config, |
|
module, |
|
input_shape=input_shape, |
|
seed=seed, |
|
dtype=dtype, |
|
_do_init=_do_init, |
|
) |
|
|
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: |
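        # build dummy inputs whose shapes drive parameter initialisation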
|
|
|
input_features = jnp.zeros(input_shape, dtype="f4") |
|
input_features = input_features.at[(..., -1)].set(self.config.eos_token_id) |
|
|
|
decoder_input_ids = jnp.zeros((input_shape[0], 1), dtype="i4") |
|
decoder_attention_mask = jnp.ones_like(decoder_input_ids) |
|
|
|
batch_size, sequence_length = decoder_input_ids.shape |
|
decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) |
|
|
|
params_rng, dropout_rng = jax.random.split(rng) |
|
rngs = {"params": params_rng, "dropout": dropout_rng} |
|
|
|
random_params = self.module.init( |
|
rngs, |
|
input_features=input_features, |
|
decoder_input_ids=decoder_input_ids, |
|
decoder_attention_mask=decoder_attention_mask, |
|
decoder_position_ids=decoder_position_ids, |
|
)["params"] |
|
|
|
if params is not None: |
|
random_params = flatten_dict(unfreeze(random_params)) |
|
params = flatten_dict(unfreeze(params)) |
|
for missing_key in self._missing_keys: |
|
params[missing_key] = random_params[missing_key] |
|
self._missing_keys = set() |
|
return freeze(unflatten_dict(params)) |
|
else: |
|
return random_params |
|
|
|
def enable_gradient_checkpointing(self): |
|
self.gradient_checkpointing = True |
|
self._module = self.module_class( |
|
config=self.config, |
|
dtype=self.dtype, |
|
use_scan=self.use_scan, |
|
gradient_checkpointing=self.gradient_checkpointing, |
|
) |
|
|
|
def enable_scan(self): |
|
self.use_scan = True |
|
self._module = self.module_class( |
|
config=self.config, |
|
dtype=self.dtype, |
|
use_scan=self.use_scan, |
|
gradient_checkpointing=self.gradient_checkpointing, |
|
) |
|
init_fn = partial(self.init_weights, input_shape=self.input_shape) |
|
params_shape_tree = jax.eval_shape(init_fn, self.key) |
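        # re-derive the parameter shape tree (without materialising weights) and the set of required parameter
        # keys so they match the scanned module layout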
|
|
|
|
|
self._params_shape_tree = params_shape_tree |
|
|
|
|
|
self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys()) |
|
|
|
|
|
if self._is_initialized: |
|
self.params = self.convert_unroll_to_scan(self.params) |
|
|
|
def disable_scan(self): |
|
self.use_scan = False |
|
self._module = self.module_class( |
|
config=self.config, |
|
dtype=self.dtype, |
|
use_scan=self.use_scan, |
|
gradient_checkpointing=self.gradient_checkpointing, |
|
) |
|
init_fn = partial(self.init_weights, input_shape=self.input_shape) |
|
params_shape_tree = jax.eval_shape(init_fn, self.key) |
|
|
|
|
|
self._params_shape_tree = params_shape_tree |
|
|
|
|
|
self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys()) |
|
|
|
|
|
if self._is_initialized: |
|
self.params = self.convert_scan_to_unroll(self.params) |
|
|
|
def convert_unroll_to_scan(self, params: Union[Dict, FrozenDict]): |
|
r""" |
|
Convert a `PyTree` of unrolled model parameters to a scanned block of model parameters. This method can be used |
|
to explicitly convert the model parameters to scanned format. This returns a new `params` tree and does not |
|
convert the `params` in place. |
|
|
|
To illustrate the workings of this method, take the Flax BERT model. The unrolled structure for the query |
|
projection params is as follows: |
|
('bert', 'encoder', 'layer', '0', 'self_attn', 'q_proj') ('bert', 'encoder', 'layer', '1', 'self_attn', |
|
'q_proj') ... ('bert', 'encoder', 'layer', '23', 'self_attn', 'q_proj') |
|
This method takes each of the `q_proj` matrices for layers (0, ..., 23) and stacks them into a single 'super' |
|
matrix, giving a *single* block of weights for all 24 layers compatible with the scanned model: |
|
('bert', 'encoder', 'layer', 'ScanLayers', 'self_attn', 'q_proj') |
|
|
|
When enabling scan with _do_init=True (default), this method will be called automatically under the hood. With |
|
_do_init=False, it will have to be called explicitly (see example below). |
|
|
|
Arguments: |
|
params (`Union[Dict, FrozenDict]`): |
|
A `PyTree` of model parameters. |
|
|
|
Examples: |
|
|
|
```python |
|
>>> from distil_whisper import FlaxWhisperForConditionalGeneration |
|
|
|
>>> # Download model and configuration from huggingface.co |
|
>>> model, params = FlaxWhisperModel.from_pretrained("openai/whisper-tiny.en", _do_init=False) |
|
>>> # By default, the model params will be in unrolled format. To illustrate the use of this method, |
|
>>> # we'll first convert to scan format and then back to unrolled |
|
>>> model.enable_scan() |
|
>>> params = model.convert_unroll_to_scan(params) |
|
>>> # now convert back to unrolled |
|
>>> model.disable_scan() |
|
>>> params = model.convert_scan_to_unroll(params) |
|
```""" |
|
if isinstance(params, FrozenDict): |
|
params = unfreeze(params) |
|
|
|
params = flatten_dict(params, sep="/") |
|
keys = list(params.keys()) |
|
|
|
for k in keys: |
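            # Parameters of the first unrolled layer (".../layers/0/...") identify which blocks need stacking.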
|
|
|
|
|
if "layers/0" in k: |
|
if "decoder" in k: |
|
block_prefix = "Decoder" |
|
num_hidden_layers = self.config.decoder_layers |
|
else: |
|
block_prefix = "Encoder" |
|
num_hidden_layers = self.config.encoder_layers |
|
|
|
|
|
|
|
scan_key = k.replace("0", f"Flax{block_prefix}ScanLayers") |
|
stacked_params = [] |
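                # Pop each unrolled layer's parameters and stack them along a new leading (layer) axis, freeing
                # the unrolled entries as we go so the conversion adds little memory overhead.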
|
|
|
|
|
for i in range(num_hidden_layers): |
|
|
|
|
|
|
|
unrolled_layer = params.pop(k.replace("0", str(i))) |
|
stacked_params.append(unrolled_layer) |
|
|
|
params[scan_key] = jnp.stack(stacked_params) |
|
|
|
|
|
params = unflatten_dict(params, sep="/") |
|
return params |
|
|
|
def convert_scan_to_unroll(self, params: Union[Dict, FrozenDict]): |
|
r""" |
|
Convert a `PyTree` of scanned model parameters to an unrolled stack of model parameters. This method can be |
|
used to explicitly convert the model parameters to unrolled format. This returns a new `params` tree and does |
|
not convert the `params` in place. |
|
|
|
To illustrate the workings of this method, take the Flax BERT model. The scanned structure for the query |
|
projection (`q_proj`) params is a single, stacked matrix of parameters over all N layers: |
|
('bert', 'encoder', 'layer', 'FlaxScanLayers', 'self_attn', 'q_proj') |
|
|
|
This method slices each layer of the `q_proj` scanned matrix into single, standalone layers, and replaces the |
|
        scanned matrix of parameters on the fly:
|
('bert', 'encoder', 'layer', '0', 'self_attn', 'q_proj') ('bert', 'encoder', 'layer', '1', 'self_attn', |
|
'q_proj') ... ('bert', 'encoder', 'layer', 'N', 'self_attn', 'q_proj') |
|
|
|
When enabling scan with _do_init=True (default), this method will be called automatically under the hood. With |
|
_do_init=False, it will have to be called explicitly (see example below). |
|
|
|
Arguments: |
|
params (`Union[Dict, FrozenDict]`): |
|
A `PyTree` of model parameters. |
|
|
|
Examples: |
|
|
|
```python |
|
>>> from distil_whisper import FlaxWhisperForConditionalGeneration |
|
|
|
>>> # Download model and configuration from huggingface.co |
|
>>> model, params = FlaxWhisperModel.from_pretrained("openai/whisper-tiny.en", _do_init=False) |
|
>>> # By default, the model params will be in unrolled format. To illustrate the use of this method, |
|
>>> # we'll first convert to scan format and then back to unrolled |
|
>>> model.enable_scan() |
|
>>> params = model.convert_unroll_to_scan(params) |
|
>>> # now convert back to unrolled |
|
>>> model.disable_scan() |
|
>>> params = model.convert_scan_to_unroll(params) |
|
```""" |
|
|
|
if isinstance(params, FrozenDict): |
|
params = unfreeze(params) |
|
|
|
params = flatten_dict(params, sep="/") |
|
keys = list(params.keys()) |
|
|
|
for k in keys: |
|
|
|
|
|
if "FlaxEncoderScanLayers" in k: |
|
|
|
scan_layer = params.pop(k) |
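                # Slice one layer at a time off the front of the stacked block and write it back under the
                # corresponding unrolled key.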
|
|
|
|
|
|
|
for i in range(self.config.encoder_layers): |
|
|
|
|
|
|
|
unrolled_key = k.replace("FlaxEncoderScanLayers", str(i)) |
|
params[unrolled_key], scan_layer = scan_layer[0], scan_layer[1:] |
|
|
|
elif "FlaxDecoderScanLayers" in k: |
|
|
|
scan_layer = params.pop(k) |
|
|
|
|
|
|
|
for i in range(self.config.decoder_layers): |
|
|
|
|
|
|
|
unrolled_key = k.replace("FlaxDecoderScanLayers", str(i)) |
|
params[unrolled_key], scan_layer = scan_layer[0], scan_layer[1:] |
|
|
|
params = unflatten_dict(params, sep="/") |
|
return params |
|
|
|
|
|
def init_cache(self, batch_size, max_length, encoder_outputs): |
|
r""" |
|
Args: |
|
batch_size (`int`): |
|
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. |
|
max_length (`int`): |
|
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized |
|
cache. |
|
encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): |
|
`encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: |
|
                `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`
|
is a sequence of hidden-states at the output of the last layer of the encoder. Used in the |
|
cross-attention of the decoder. |
|
""" |
|
|
|
decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") |
|
decoder_attention_mask = jnp.ones_like(decoder_input_ids) |
|
decoder_position_ids = jnp.broadcast_to( |
|
jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), |
|
decoder_input_ids.shape, |
|
) |
|
|
|
def _decoder_forward( |
|
module, |
|
decoder_input_ids, |
|
decoder_attention_mask, |
|
decoder_position_ids, |
|
**kwargs, |
|
): |
|
decoder_module = module._get_decoder_module() |
|
return decoder_module( |
|
decoder_input_ids, |
|
decoder_attention_mask, |
|
decoder_position_ids, |
|
**kwargs, |
|
) |
|
|
|
init_variables = self.module.init( |
|
jax.random.PRNGKey(0), |
|
decoder_input_ids=decoder_input_ids, |
|
decoder_attention_mask=decoder_attention_mask, |
|
decoder_position_ids=decoder_position_ids, |
|
encoder_hidden_states=encoder_outputs[0], |
|
init_cache=True, |
|
method=_decoder_forward, |
|
) |
|
return unfreeze(init_variables["cache"]) |
|
|
|
@add_start_docstrings(WHISPER_ENCODE_INPUTS_DOCSTRING) |
|
@replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=WhisperConfig) |
|
def encode( |
|
self, |
|
input_features: jnp.ndarray, |
|
attention_mask: Optional[jnp.ndarray] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
train: bool = False, |
|
params: dict = None, |
|
dropout_rng: PRNGKey = None, |
|
**kwargs, |
|
): |
|
r""" |
|
Returns: |
|
|
|
Example: |
|
|
|
```python |
|
>>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration |
|
>>> from datasets import load_dataset |
|
|
|
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") |
|
>>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True) |
|
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") |
|
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np") |
|
>>> input_features = inputs.input_features |
|
>>> encoder_outputs = model.encode(input_features=input_features) |
|
```""" |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.return_dict |
|
|
|
|
|
rngs = {} |
|
if dropout_rng is not None: |
|
rngs["dropout"] = dropout_rng |
|
|
|
def _encoder_forward(module, input_features, **kwargs): |
|
encode_module = module._get_encoder_module() |
|
return encode_module(input_features, **kwargs) |
|
|
|
return self.module.apply( |
|
{"params": params or self.params}, |
|
input_features=jnp.array(input_features, dtype="f4"), |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
deterministic=not train, |
|
rngs=rngs, |
|
method=_encoder_forward, |
|
) |
|
|
|
@add_start_docstrings(WHISPER_DECODE_INPUTS_DOCSTRING) |
|
@replace_return_docstrings( |
|
output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, |
|
config_class=WhisperConfig, |
|
) |
|
def decode( |
|
self, |
|
decoder_input_ids, |
|
encoder_outputs, |
|
encoder_attention_mask: Optional[jnp.ndarray] = None, |
|
decoder_attention_mask: Optional[jnp.ndarray] = None, |
|
decoder_position_ids: Optional[jnp.ndarray] = None, |
|
past_key_values: dict = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
train: bool = False, |
|
params: dict = None, |
|
dropout_rng: PRNGKey = None, |
|
): |
|
r""" |
|
Returns: |
|
|
|
Example: |
|
|
|
```python |
|
>>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration |
|
>>> from datasets import load_dataset |
|
|
|
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") |
|
>>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True) |
|
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") |
|
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np") |
|
>>> input_features = inputs.input_features |
|
>>> encoder_outputs = model.encode(input_features=input_features) |
|
>>> decoder_start_token_id = model.config.decoder_start_token_id |
|
|
|
>>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id |
|
|
|
>>> outputs = model.decode(decoder_input_ids, encoder_outputs) |
|
>>> last_decoder_hidden_states = outputs.last_hidden_state |
|
```""" |
|
|
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.return_dict |
|
|
|
encoder_hidden_states = encoder_outputs[0] |
|
|
|
batch_size, sequence_length = decoder_input_ids.shape |
|
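        # Decoder position ids must be provided explicitly together with `past_key_values`, since
        # the current offset into the cache cannot be inferred here; otherwise they are derived
        # from the attention mask or default to [0, ..., sequence_length - 1].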
if decoder_position_ids is None: |
|
if past_key_values is not None: |
|
raise ValueError("Make sure to provide `decoder_position_ids` when passing" " `past_key_values`.") |
|
|
|
if decoder_attention_mask is not None: |
|
decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1 |
|
else: |
|
decoder_position_ids = jnp.broadcast_to( |
|
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) |
|
) |
|
|
|
if decoder_attention_mask is None: |
|
decoder_attention_mask = jnp.ones((batch_size, sequence_length)) |
|
|
|
|
|
rngs = {} |
|
if dropout_rng is not None: |
|
rngs["dropout"] = dropout_rng |
|
|
|
inputs = {"params": params or self.params} |
|
|
|
|
|
|
|
|
|
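        # If `past_key_values` are passed, the cache is already initialised; hand it to
        # `module.apply` as a mutable "cache" collection so the attention layers can update it.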
if past_key_values: |
|
inputs["cache"] = past_key_values |
|
mutable = ["cache"] |
|
else: |
|
mutable = False |
|
|
|
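        # Run only the decoder, cross-attending to the pre-computed encoder hidden states.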
def _decoder_forward( |
|
module, |
|
decoder_input_ids, |
|
decoder_attention_mask, |
|
decoder_position_ids, |
|
**kwargs, |
|
): |
|
decoder_module = module._get_decoder_module() |
|
return decoder_module( |
|
input_ids=decoder_input_ids, |
|
attention_mask=decoder_attention_mask, |
|
position_ids=decoder_position_ids, |
|
**kwargs, |
|
) |
|
|
|
outputs = self.module.apply( |
|
inputs, |
|
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), |
|
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), |
|
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), |
|
encoder_hidden_states=encoder_hidden_states, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
deterministic=not train, |
|
rngs=rngs, |
|
mutable=mutable, |
|
method=_decoder_forward, |
|
) |
|
|
|
|
|
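        # Append the updated cache to the outputs so it can be fed back in on the next step.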
if past_key_values is not None and return_dict: |
|
outputs, past = outputs |
|
outputs["past_key_values"] = unfreeze(past["cache"]) |
|
return outputs |
|
elif past_key_values is not None and not return_dict: |
|
outputs, past = outputs |
|
outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] |
|
|
|
return outputs |
|
|
|
@add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING) |
|
def __call__( |
|
self, |
|
input_features: jnp.ndarray, |
|
decoder_input_ids: jnp.ndarray, |
|
attention_mask: Optional[jnp.ndarray] = None, |
|
decoder_attention_mask: Optional[jnp.ndarray] = None, |
|
position_ids: Optional[jnp.ndarray] = None, |
|
decoder_position_ids: Optional[jnp.ndarray] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
freeze_encoder: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
train: bool = False, |
|
params: dict = None, |
|
dropout_rng: PRNGKey = None, |
|
): |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.return_dict |
|
|
|
|
|
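        # Derive decoder position ids from the attention mask (cumulative sum over non-padded
        # positions) when they are not supplied explicitly.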
if decoder_position_ids is None: |
|
if decoder_attention_mask is not None: |
|
decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1 |
|
else: |
|
batch_size, sequence_length = decoder_input_ids.shape |
|
decoder_position_ids = jnp.broadcast_to( |
|
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) |
|
) |
|
if decoder_attention_mask is None: |
|
decoder_attention_mask = jnp.ones_like(decoder_input_ids) |
|
|
|
|
|
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} |
|
|
|
return self.module.apply( |
|
{"params": params or self.params}, |
|
input_features=jnp.array(input_features, dtype="f4"), |
|
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), |
|
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), |
|
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
freeze_encoder=freeze_encoder, |
|
return_dict=return_dict, |
|
deterministic=not train, |
|
rngs=rngs, |
|
) |
|
|
|
|
|
@add_start_docstrings( |
|
("The bare Whisper Model transformer outputting raw hidden-states without any" " specific head on top."), |
|
WHISPER_START_DOCSTRING, |
|
) |
|
class FlaxWhisperModel(FlaxWhisperPreTrainedModel): |
|
config: WhisperConfig |
|
dtype: jnp.dtype = jnp.float32 |
|
params_dtype: jnp.dtype = jnp.float32 |
|
module_class = FlaxWhisperModule |
|
|
|
|
|
append_call_sample_docstring(FlaxWhisperModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC) |
|
|
|
|
|
class FlaxWhisperForConditionalGenerationModule(nn.Module): |
|
config: WhisperConfig |
|
dtype: jnp.dtype = jnp.float32 |
|
params_dtype: jnp.dtype = jnp.float32 |
|
use_scan: bool = False |
|
gradient_checkpointing: bool = False |
|
|
|
def setup(self) -> None: |
|
self.model = FlaxWhisperModule( |
|
config=self.config, |
|
dtype=self.dtype, |
|
params_dtype=self.params_dtype, |
|
use_scan=self.use_scan, |
|
gradient_checkpointing=self.gradient_checkpointing, |
|
) |
|
self.lm_head = DenseGeneral( |
|
self.config.vocab_size, |
|
use_bias=False, |
|
dtype=self.dtype, |
|
params_dtype=self.params_dtype, |
|
kernel_axes=("embed", "vocab"), |
|
) |
|
|
|
def _get_encoder_module(self): |
|
return self.model.encoder |
|
|
|
def _get_decoder_module(self): |
|
return self.model.decoder |
|
|
|
def __call__( |
|
self, |
|
input_features, |
|
decoder_input_ids, |
|
decoder_attention_mask: jnp.ndarray = None, |
|
decoder_position_ids: jnp.ndarray = None, |
|
position_ids: jnp.ndarray = None, |
|
attention_mask: jnp.ndarray = None, |
|
output_attentions: bool = False, |
|
output_hidden_states: bool = False, |
|
freeze_encoder: bool = False, |
|
return_dict: bool = True, |
|
deterministic: bool = True, |
|
): |
|
outputs = self.model( |
|
input_features=input_features, |
|
decoder_input_ids=decoder_input_ids, |
|
decoder_attention_mask=decoder_attention_mask, |
|
decoder_position_ids=decoder_position_ids, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
freeze_encoder=freeze_encoder, |
|
return_dict=return_dict, |
|
deterministic=deterministic, |
|
) |
|
|
|
hidden_states = outputs[0] |
|
|
|
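        # With tied word embeddings, the decoder's token embedding matrix (transposed) is used
        # as the kernel of the LM head.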
if self.config.tie_word_embeddings: |
|
shared_embedding = self.model.decoder.embed_tokens.variables["params"]["embedding"] |
|
lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) |
|
else: |
|
lm_logits = self.lm_head(hidden_states) |
|
|
|
if not return_dict: |
|
output = (lm_logits,) + outputs[1:] |
|
return output |
|
|
|
return FlaxSeq2SeqLMOutput( |
|
logits=lm_logits, |
|
decoder_hidden_states=outputs.decoder_hidden_states, |
|
decoder_attentions=outputs.decoder_attentions, |
|
cross_attentions=outputs.cross_attentions, |
|
encoder_last_hidden_state=outputs.encoder_last_hidden_state, |
|
encoder_hidden_states=outputs.encoder_hidden_states, |
|
encoder_attentions=outputs.encoder_attentions, |
|
) |
|
|
|
|
|
@add_start_docstrings("The Whisper Model with a language modeling head.", WHISPER_START_DOCSTRING) |
|
class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel): |
|
module_class = FlaxWhisperForConditionalGenerationModule |
|
|
|
@add_start_docstrings(WHISPER_DECODE_INPUTS_DOCSTRING) |
|
@replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=WhisperConfig) |
|
def decode( |
|
self, |
|
decoder_input_ids, |
|
encoder_outputs, |
|
encoder_attention_mask: Optional[jnp.ndarray] = None, |
|
decoder_attention_mask: Optional[jnp.ndarray] = None, |
|
decoder_position_ids: Optional[jnp.ndarray] = None, |
|
past_key_values: dict = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
train: bool = False, |
|
params: dict = None, |
|
dropout_rng: PRNGKey = None, |
|
): |
|
r""" |
|
Returns: |
|
|
|
Example: |
|
|
|
```python |
|
        >>> import jax.numpy as jnp
        >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
        >>> from datasets import load_dataset
|
|
|
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") |
|
>>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True) |
|
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") |
|
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np") |
|
>>> input_features = inputs.input_features |
|
>>> encoder_outputs = model.encode(input_features=input_features) |
|
>>> decoder_start_token_id = model.config.decoder_start_token_id |
|
|
|
        >>> decoder_input_ids = jnp.ones((input_features.shape[0], 1), dtype="i4") * decoder_start_token_id
|
|
|
>>> outputs = model.decode(decoder_input_ids, encoder_outputs) |
|
        >>> logits = outputs.logits
|
```""" |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.return_dict |
|
|
|
encoder_hidden_states = encoder_outputs[0] |
|
|
|
batch_size, sequence_length = decoder_input_ids.shape |
|
if decoder_position_ids is None: |
|
if past_key_values is not None: |
|
raise ValueError("Make sure to provide `decoder_position_ids` when passing" " `past_key_values`.") |
|
|
|
if decoder_attention_mask is not None: |
|
decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1 |
|
else: |
|
decoder_position_ids = jnp.broadcast_to( |
|
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) |
|
) |
|
if decoder_attention_mask is None: |
|
decoder_attention_mask = jnp.ones((batch_size, sequence_length), dtype="i4") |
|
|
|
|
|
rngs = {} |
|
if dropout_rng is not None: |
|
rngs["dropout"] = dropout_rng |
|
|
|
inputs = {"params": params or self.params} |
|
|
|
|
|
|
|
|
|
if past_key_values: |
|
inputs["cache"] = past_key_values |
|
mutable = ["cache"] |
|
else: |
|
mutable = False |
|
|
|
def _decoder_forward( |
|
module, |
|
decoder_input_ids, |
|
decoder_attention_mask, |
|
decoder_position_ids, |
|
**kwargs, |
|
): |
|
decoder_module = module._get_decoder_module() |
|
outputs = decoder_module( |
|
input_ids=decoder_input_ids, |
|
attention_mask=decoder_attention_mask, |
|
position_ids=decoder_position_ids, |
|
**kwargs, |
|
) |
|
hidden_states = outputs[0] |
|
|
|
if self.config.tie_word_embeddings: |
|
shared_embedding = module.model.decoder.embed_tokens.variables["params"]["embedding"] |
|
lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) |
|
else: |
|
lm_logits = module.lm_head(hidden_states) |
|
|
|
return lm_logits, outputs |
|
|
|
outputs = self.module.apply( |
|
inputs, |
|
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), |
|
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), |
|
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), |
|
encoder_hidden_states=encoder_hidden_states, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
deterministic=not train, |
|
rngs=rngs, |
|
mutable=mutable, |
|
method=_decoder_forward, |
|
) |
|
|
|
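        # With `mutable=["cache"]`, `module.apply` returns (outputs, mutated_variables); unpack
        # the updated cache when past key values were used.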
if past_key_values is None: |
|
lm_logits, decoder_outputs = outputs |
|
else: |
|
(lm_logits, decoder_outputs), past = outputs |
|
|
|
if return_dict: |
|
outputs = FlaxCausalLMOutputWithCrossAttentions( |
|
logits=lm_logits, |
|
hidden_states=decoder_outputs.hidden_states, |
|
attentions=decoder_outputs.attentions, |
|
cross_attentions=decoder_outputs.cross_attentions, |
|
) |
|
else: |
|
outputs = (lm_logits,) + decoder_outputs[1:] |
|
|
|
|
|
if past_key_values is not None and return_dict: |
|
outputs["past_key_values"] = unfreeze(past["cache"]) |
|
return outputs |
|
elif past_key_values is not None and not return_dict: |
|
outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] |
|
|
|
return outputs |
|
|
|
def generate( |
|
self, |
|
input_features, |
|
generation_config=None, |
|
logits_processor=None, |
|
return_timestamps=None, |
|
task=None, |
|
language=None, |
|
is_multilingual=None, |
|
**kwargs, |
|
): |
|
if generation_config is None: |
|
generation_config = self.generation_config |
|
|
|
if return_timestamps is not None: |
|
generation_config.return_timestamps = return_timestamps |
|
|
|
if task is not None: |
|
generation_config.task = task |
|
|
|
if is_multilingual is not None: |
|
generation_config.is_multilingual = is_multilingual |
|
|
|
if language is not None: |
|
generation_config.language = language |
|
|
|
        if "decoder_input_ids" in kwargs:
|
decoder_input_length = len(kwargs["decoder_input_ids"]) |
|
else: |
|
decoder_input_length = 1 |
|
|
|
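        # Build the forced decoder ids as (position, token_id) pairs: position 1 forces the
        # language token and position 2 the task token for multilingual checkpoints.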
forced_decoder_ids = [] |
|
|
|
if hasattr(generation_config, "is_multilingual") and generation_config.is_multilingual: |
|
if hasattr(generation_config, "language"): |
|
forced_decoder_ids.append((1, generation_config.lang_to_id[generation_config.language])) |
|
else: |
|
forced_decoder_ids.append((1, None)) |
|
|
|
if hasattr(generation_config, "task"): |
|
forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task])) |
|
else: |
|
forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"])) |
|
|
|
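        # If timestamps are requested, use the Whisper timestamp logits processor; otherwise
        # force the <|notimestamps|> token right after the other forced ids.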
if ( |
|
hasattr(generation_config, "return_timestamps") and generation_config.return_timestamps |
|
) or return_timestamps: |
|
logits_processor = [ |
|
FlaxWhisperTimeStampLogitsProcessor(generation_config, self.config, decoder_input_length) |
|
] |
|
else: |
|
if forced_decoder_ids and forced_decoder_ids[-1][0] != generation_config.no_timestamps_token_id: |
|
idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1 |
|
forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id)) |
|
|
|
if len(forced_decoder_ids) > 0: |
|
generation_config.forced_decoder_ids = forced_decoder_ids |
|
|
|
return super().generate( |
|
input_features, |
|
generation_config, |
|
logits_processor=logits_processor, |
|
**kwargs, |
|
) |
|
|
|
def pipeline_generate( |
|
self, |
|
input_features, |
|
forced_decoder_ids, |
|
return_timestamps=False, |
|
generation_config=None, |
|
**kwargs, |
|
): |
|
if generation_config is None: |
|
generation_config = self.generation_config |
|
|
|
|
|
generation_config.forced_decoder_ids = None |
|
|
|
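        # Apply the forced decoder ids through a logits processor rather than through
        # `generation_config.forced_decoder_ids` (reset to None above).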
logits_processor = FlaxLogitsProcessorList() |
|
|
|
logits_processor.append(FlaxStaticForceTokensLogitsProcessor(forced_decoder_ids)) |
|
|
|
if hasattr(generation_config, "return_timestamps") and return_timestamps: |
|
logits_processor.append(FlaxWhisperTimeStampLogitsProcessor(generation_config, self.config, 1)) |
|
|
|
return super().generate( |
|
input_features, |
|
generation_config, |
|
logits_processor=logits_processor, |
|
**kwargs, |
|
) |
|
|
|
def prepare_inputs_for_generation( |
|
self, |
|
decoder_input_ids, |
|
max_length, |
|
attention_mask: Optional[jax.Array] = None, |
|
decoder_attention_mask: Optional[jax.Array] = None, |
|
encoder_outputs=None, |
|
**kwargs, |
|
): |
|
|
|
batch_size, seq_length = decoder_input_ids.shape |
|
|
|
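        # Initialise the decoder cache for up to `max_length` tokens.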
past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) |
|
|
|
|
|
|
|
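        # Build a static attention mask over the full `max_length`; positions beyond the current
        # sequence length are handled by the decoder's causal mask, so they can be left as ones.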
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") |
|
if decoder_attention_mask is not None: |
|
position_ids = decoder_attention_mask.cumsum(-1) - 1 |
|
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0)) |
|
else: |
|
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) |
|
|
|
return { |
|
"past_key_values": past_key_values, |
|
"encoder_outputs": encoder_outputs, |
|
"encoder_attention_mask": attention_mask, |
|
"decoder_attention_mask": extended_attention_mask, |
|
"decoder_position_ids": position_ids, |
|
} |
|
|
|
def update_inputs_for_generation(self, model_outputs, model_kwargs): |
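        # Carry the updated cache forward and advance the decoder position id by one for the
        # next autoregressive step.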
|
model_kwargs["past_key_values"] = model_outputs.past_key_values |
|
model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1 |
|
return model_kwargs |
|
|
|
|
|
FLAX_WHISPER_CONDITIONAL_GENERATION_DOCSTRING = r""" |
|
Returns: |
|
|
|
Transcription example: |
|
|
|
```python |
|
>>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration |
|
>>> from datasets import load_dataset |
|
|
|
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") |
|
>>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True) |
|
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") |
|
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np") |
|
>>> input_features = inputs.input_features |
|
    >>> generated_ids = model.generate(input_features=input_features)
|
>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] |
|
>>> transcription |
|
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.' |
|
``` |
|
""" |
|
|
|
overwrite_call_docstring( |
|
FlaxWhisperForConditionalGeneration, |
|
WHISPER_INPUTS_DOCSTRING + FLAX_WHISPER_CONDITIONAL_GENERATION_DOCSTRING, |
|
) |
|
append_replace_return_docstrings( |
|
FlaxWhisperForConditionalGeneration, |
|
output_type=FlaxSeq2SeqLMOutput, |
|
config_class=_CONFIG_FOR_DOC, |
|
) |
|
|