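# (file-browser header captured together with the source; kept as comments so the module parses)
# training / flax /distil_whisper /modeling_flax_whisper.py
# supawichwac's picture
# Saving train state of step 50
# 55f3766 verified
# raw
# history blame
# 92.2 kB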
# coding=utf-8
# Copyright 2023 The OpenAI Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Flax whisper model."""
import random
from functools import partial
from typing import Dict, Optional, Tuple, Union
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.linen.partitioning import remat, scan_with_axes
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey
from transformers import WhisperConfig
from transformers.generation.flax_logits_process import (
FlaxLogitsProcessor,
FlaxLogitsProcessorList,
FlaxWhisperTimeStampLogitsProcessor,
)
from transformers.modeling_flax_outputs import (
FlaxBaseModelOutput,
FlaxBaseModelOutputWithPastAndCrossAttentions,
FlaxCausalLMOutputWithCrossAttentions,
FlaxSeq2SeqLMOutput,
FlaxSeq2SeqModelOutput,
)
from transformers.modeling_flax_utils import (
ACT2FN,
FlaxPreTrainedModel,
append_call_sample_docstring,
append_replace_return_docstrings,
overwrite_call_docstring,
)
from transformers.utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .layers import Conv, DenseGeneral, Embed, LayerNorm, with_sharding_constraint
# Module-level logger, named after this module.
logger = logging.get_logger(__name__)
# Checkpoint / config names for documentation; presumably consumed by the docstring helpers imported above
# (their use is not visible in this section of the file).
_CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
_CONFIG_FOR_DOC = "WhisperConfig"
# Docstring preamble for the model classes: describes the `FlaxPreTrainedModel` inheritance, the supported JAX
# transformations (jit, grad, vmap, pmap) and the `config` / `dtype` constructor parameters.
WHISPER_START_DOCSTRING = r"""
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.) This model is also a Flax Linen
[flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
Finally, this model supports inherent JAX features such as:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
Parameters:
config ([`WhisperConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
`jax.numpy.bfloat16` (on TPUs). This can be used to enable mixed-precision training or half-precision
inference on GPUs or TPUs. If specified all the computation will be performed with the given `dtype`.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.** If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`]
and [`~FlaxPreTrainedModel.to_bf16`].
"""
# Argument documentation for the full encoder-decoder forward pass: audio features, decoder token inputs,
# masks, position ids and the output-control flags.
WHISPER_INPUTS_DOCSTRING = r"""
Args:
input_features (`numpy.ndarray` of shape `(batch_size, feature_size, sequence_length)`):
Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`WhisperFeatureExtractor`] should be used for extracting the features, padding and conversion into a
tensor of type `numpy.ndarray`. See [`~WhisperFeatureExtractor.__call__`]
attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, but
is not used. By default the silence in the input log mel spectrogram are ignored.
decoder_input_ids (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using
[`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
[What are decoder input IDs?](../glossary#decoder-input-ids) Whisper uses the `decoder_start_token_id` as
the starting token for `decoder_input_ids` generation.
decoder_attention_mask (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
be used by default. If you want to change padding behavior, you should modify to your needs. See diagram 1
in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Whisper does not use `position_ids` in the encoder as `input_features` is always the same size and doesn't
use masking, but this argument is preserved for compatibility. By default the silence in the input log mel
spectrogram are ignored.
decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
range `[0, config.max_position_embeddings - 1]`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
# Argument documentation for the encoder-only call: audio features, the (unused) attention mask and the
# output-control flags.
WHISPER_ENCODE_INPUTS_DOCSTRING = r"""
Args:
input_features (`numpy.ndarray` of shape `(batch_size, feature_size, sequence_length)`):
Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`WhisperFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
tensor of type `numpy.ndarray`. See [`~WhisperFeatureExtractor.__call__`].
attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, but
is not used. By default the silence in the input log mel spectrogram are ignored.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
# Argument documentation for the decoder-only call: decoder token inputs, pre-computed encoder outputs,
# masks, position ids, the key/value cache (`past_key_values`) and the output-control flags.
WHISPER_DECODE_INPUTS_DOCSTRING = r"""
Args:
decoder_input_ids (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`):
Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using
[`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
[What are decoder input IDs?](../glossary#decoder-input-ids)
encoder_outputs (`tuple(tuple(numpy.ndarray)`):
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
encoder_attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,
but it is not used. By default the silence in the input log mel spectrogram are ignored.
decoder_attention_mask (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
be used by default. If you want to change padding behavior, you should modify to your needs. See diagram 1
in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
range `[0, config.max_position_embeddings - 1]`.
past_key_values (`Dict[str, numpy.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class FlaxStaticForceTokensLogitsProcessor(FlaxLogitsProcessor):
    r"""
    [`FlaxLogitsProcessor`] that takes a list of pairs of integers which indicates a mapping from generation indices to
    token indices that will be forced before sampling. The processor will set their log probs to 0 and all other tokens
    to `-inf` so that they are sampled at their corresponding index. This is a static version of the `transformers` logit
    processor [`FlaxForceTokensLogitsProcessor`] that is compatible with sharded forced tokens.

    Args:
        force_token_map (`list`):
            Map giving token ids and indices where they will be forced to be sampled.
    """

    # Fixed length of the lookup table (array shapes need to be concrete for XLA compatibility). The last position
    # that ever needs to be forced in Whisper is index 3, so the table must cover indices 0..3, i.e. four entries.
    # A shorter table makes `.at[3].set(...)` an out-of-bounds update and makes `__call__` skip position 3.
    _FORCE_TABLE_LENGTH = 4

    def __init__(self, force_token_map):
        """Build the static `position -> forced token id` table from `[[index, token], ...]` pairs."""
        # The generic `transformers` logit processor builds `force_token_array` as a dictionary - this is not a valid
        # JAX type, and so we switch to using a JAX array instead
        force_token_map = jnp.array(force_token_map)
        # Converts the array of format [[index, token]] containing the tokens to be forced to an array, where the
        # index of the array corresponds to the index of the token to be forced. For XLA compatibility,
        # indexes without forced tokens will have a negative value. The native version constructs a tensor
        # dynamically according to the length of the `force_token_map`; that is not permitted here because the
        # shape has to be static, hence the fixed `_FORCE_TABLE_LENGTH`.
        force_token_array = jnp.full((self._FORCE_TABLE_LENGTH,), -1, dtype=jnp.int32)
        for index, token in force_token_map:
            force_token_array = force_token_array.at[index].set(token)
        self.force_token_array = jnp.int32(force_token_array)

    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
        """Return `scores`, or scores that force the token registered for position `cur_len`.

        Args:
            input_ids: Tokens generated so far (not read by this processor).
            scores: Scores for the next token; axis 0 is the batch and the update below writes along axis 1.
            cur_len: Index of the position currently being generated.

        Returns:
            `scores` unchanged when no token is registered for `cur_len`; otherwise an array of the same shape
            filled with `-inf` except for a 0 at the forced token id in every row.
        """

        def _force_token(generation_idx):
            batch_size = scores.shape[0]
            current_token = self.force_token_array[generation_idx]
            # everything is -inf except the forced token column, which gets log-prob 0
            new_scores = jnp.full_like(scores, -float("inf"))
            updates = jnp.zeros((batch_size, 1), dtype=scores.dtype)
            return lax.dynamic_update_slice(new_scores, updates, (0, current_token))

        scores = lax.cond(
            cur_len >= self.force_token_array.shape[0],
            # If the current length is geq than the length of force_token_array, the processor does nothing.
            lambda: scores,
            # Otherwise, it may force a certain token.
            lambda: lax.cond(
                self.force_token_array[cur_len] >= 0,
                # Only valid (non-negative) tokens are forced
                lambda: _force_token(cur_len),
                # Otherwise, the processor does nothing.
                lambda: scores,
            ),
        )
        return scores
class FlaxWhisperAttention(nn.Module):
    """Multi-head attention module used for both self-attention and cross-attention.

    Inputs are projected to queries, keys and values, attention weights are computed with
    `dot_product_attention_weights`, and the merged heads are projected back to `embed_dim`.
    With `causal=True` a causal mask is applied and, when a `cache` collection exists (or
    `init_cache=True`), keys and values are accumulated in that cache across calls.
    """

    config: WhisperConfig
    embed_dim: int  # width of the projections; must be divisible by `num_heads`
    num_heads: int
    dropout: float = 0.0  # dropout rate handed to `dot_product_attention_weights`
    causal: bool = False  # build/apply a causal mask and enable the key/value cache path
    bias: bool = True  # bias for q/v/out projections (the key projection never has one)
    dtype: jnp.dtype = jnp.float32  # dtype used for the computation and the attention bias
    params_dtype: jnp.dtype = jnp.float32  # forwarded to the `DenseGeneral` layers

    def setup(self) -> None:
        """Create the four projection layers and, for causal modules, the static causal mask.

        Raises:
            ValueError: If `embed_dim` is not divisible by `num_heads`.
        """
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                "embed_dim must be divisible by num_heads (got `embed_dim`:"
                f" {self.embed_dim} and `num_heads`: {self.num_heads})."
            )
        # Shared constructor for the q/k/v projections; only `use_bias` differs between them.
        dense = partial(
            DenseGeneral,
            self.embed_dim,
            axis=-1,
            dtype=self.dtype,
            params_dtype=self.params_dtype,
            kernel_axes=("embed", "joined_kv"),
        )
        self.q_proj = dense(use_bias=self.bias)
        self.k_proj = dense(use_bias=False)
        self.v_proj = dense(use_bias=self.bias)
        self.out_proj = DenseGeneral(
            self.embed_dim,
            axis=-1,
            dtype=self.dtype,
            params_dtype=self.params_dtype,
            kernel_axes=("joined_kv", "embed"),
            use_bias=self.bias,
        )
        if self.causal:
            # Boolean causal mask covering `config.max_target_positions` positions; sliced per call.
            self.causal_mask = make_causal_mask(
                jnp.ones((1, self.config.max_target_positions), dtype="bool"),
                dtype="bool",
            )

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        key_value_states: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """Attend from `hidden_states` to itself or, if given, to `key_value_states`.

        Args:
            hidden_states: Query-side input; axis 0 is the batch and the projection is split into
                `(batch, length, num_heads, head_dim)`.
            key_value_states: When not `None`, keys and values are projected from it (cross-attention);
                otherwise they are projected from `hidden_states`.
            attention_mask: Optional mask; entries `> 0` are attended to, all others receive the most negative
                finite value of `self.dtype` as bias. It is expanded on axes `(-3, -2)`, so presumably it is
                `(batch, key_length)` — TODO confirm against callers.
            init_cache: On a causal module, `True` routes the call through `_concatenate_to_cache`, which
                declares the cache variables.
            deterministic: When `False` and `self.dropout > 0`, a `"dropout"` RNG is drawn and attention
                dropout is applied.

        Returns:
            The pair `(attn_output, attn_weights)`.
        """
        is_cross_attention = key_value_states is not None
        batch_size = hidden_states.shape[0]
        query_states = self.q_proj(hidden_states)
        if is_cross_attention:
            key_states = self.k_proj(key_value_states)
            value_states = self.v_proj(key_value_states)
        else:
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)
        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)
        query_states = with_sharding_constraint(query_states, ("batch", "length", "heads", "kv"))
        key_states = with_sharding_constraint(key_states, ("batch", "length", "heads", "kv"))
        value_states = with_sharding_constraint(value_states, ("batch", "length", "heads", "kv"))
        if self.causal:
            query_length, key_length = query_states.shape[1], key_states.shape[1]
            if self.has_variable("cache", "cached_key"):
                # With a cache, the mask rows start at the current cache index and span the full cache length.
                mask_shift = self.variables["cache"]["cache_index"]
                # max_length of cached_key is last dim
                max_decoder_length = self.variables["cache"]["cached_key"].shape[-1]
                causal_mask = lax.dynamic_slice(
                    self.causal_mask,
                    (0, 0, mask_shift, 0),
                    (1, 1, query_length, max_decoder_length),
                )
            else:
                causal_mask = self.causal_mask[:, :, :query_length, :key_length]
            causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
        # combine masks if needed
        if attention_mask is not None and self.causal:
            attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
            attention_mask = combine_masks(attention_mask, causal_mask)
        elif self.causal:
            attention_mask = causal_mask
        elif attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
            key_states, value_states, attention_mask = self._concatenate_to_cache(
                key_states, value_states, query_states, attention_mask
            )
        # Convert the boolean attention mask to an attention bias.
        if attention_mask is not None:
            # attention mask in the form of attention bias
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
            )
        else:
            attention_bias = None
        # A dropout RNG is only requested when dropout can actually be applied.
        dropout_rng = None
        if not deterministic and self.dropout > 0.0:
            dropout_rng = self.make_rng("dropout")
        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )
        # Weighted sum of the values per head, then merge heads and project back to `embed_dim`.
        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)
        return attn_output, attn_weights

    def _split_heads(self, hidden_state) -> jnp.ndarray:
        """Reshape the last axis into `(num_heads, head_dim)`, keeping the first two axes."""
        return hidden_state.reshape(hidden_state.shape[:2] + (self.num_heads, self.head_dim))

    def _merge_heads(self, hidden_state) -> jnp.ndarray:
        """Inverse of `_split_heads`: collapse the trailing head axes into `embed_dim`."""
        return hidden_state.reshape(hidden_state.shape[:2] + (self.embed_dim,))

    @nn.compact
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """Write the new keys/values into the `cache` collection and return the full cached tensors.

        On the call that first declares the cache variables (`is_initialized` is `False`), the inputs are
        returned unchanged. Afterwards the cache is updated, `cache_index` is advanced by the number of
        query positions, and `attention_mask` is combined with a mask hiding not-yet-filled cache slots.

        Returns:
            The tuple `(key, value, attention_mask)`.

        Raises:
            ValueError: If a single-position query does not match the cached key shape.
        """
        # The following code is largely copied from: https://github.com/google-research/t5x/blob/63d9addf628c6d8c547a407a32095fcb527bb20b/t5x/examples/scalable_t5/layers.py#L280-L284
        is_initialized = self.has_variable("cache", "cached_key")

        # The key and value have dimension [batch_size, seq_length, num_heads, head_dim],
        # but we cache them as [batch_size, num_heads, head_dim, seq_length] as a TPU
        # fusion optimization. This also enables the "scatter via one-hot
        # broadcast" trick, which means we do a one-hot broadcast instead of a
        # scatter/gather operations, resulting in a 3-4x speedup in practice.
        def swap_dims(x):
            return x[:-3] + tuple(x[i] for i in [-2, -1, -3])

        cached_key = self.variable("cache", "cached_key", jnp.zeros, swap_dims(key.shape), key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros, swap_dims(value.shape), value.dtype)
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
        if is_initialized:
            batch_size, num_heads, head_dim, seq_length = cached_key.value.shape
            # During fast autoregressive decoding, we feed one position at a time,
            # and cache the keys and values step by step.
            # Sanity shape check of cached key against input query.
            num_updated_cache_vectors = query.shape[1]
            expected_shape = (batch_size, 1, num_heads, head_dim)
            if num_updated_cache_vectors == 1 and expected_shape != query.shape:
                raise ValueError(
                    "Autoregressive cache shape error, expected query shape"
                    f" {expected_shape} instead got {query.shape}"
                )
            # Create a OHE of the current index. NOTE: the index is increased below.
            cur_index = cache_index.value
            # In order to update the key, value caches with the current key and
            # value, we move the seq_length axis to the back, similar to what we did for
            # the cached ones above.
            # Note these are currently the key and value of a single position, since
            # we feed one position at a time.
            one_token_key = jnp.moveaxis(key, -3, -1)
            one_token_value = jnp.moveaxis(value, -3, -1)
            # Update key, value caches with our new 1d spatial slices.
            # We implement an efficient scatter into the cache via one-hot
            # broadcast and addition.
            if num_updated_cache_vectors > 1:
                # NOTE(review): `jnp.eye(n, seq_length)` places the n new vectors at cache positions 0..n-1
                # independently of `cur_index` — looks like it assumes an empty cache; confirm with callers.
                indices = jnp.eye(num_updated_cache_vectors, seq_length)[None, None]
                key = cached_key.value + jnp.matmul(one_token_key, indices)
                value = cached_value.value + jnp.matmul(one_token_value, indices)
            else:
                one_hot_indices = jax.nn.one_hot(cur_index, seq_length, dtype=key.dtype)
                key = cached_key.value + one_token_key * one_hot_indices
                value = cached_value.value + one_token_value * one_hot_indices
            cached_key.value = key
            cached_value.value = value
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # Move the keys and values back to their original shapes.
            key = jnp.moveaxis(key, -1, -3)
            value = jnp.moveaxis(value, -1, -3)
            # causal mask for cached decoder self-attention: our single query position should only
            # attend to those key positions that have already been generated and cached, not the
            # remaining zero elements.
            pad_mask = jnp.broadcast_to(
                jnp.arange(seq_length) < cur_index + num_updated_cache_vectors,
                (batch_size,) + (1, num_updated_cache_vectors, seq_length),
            )
            attention_mask = combine_masks(pad_mask, attention_mask)
        return key, value, attention_mask
class FlaxWhisperEncoderLayer(nn.Module):
    """Single pre-LayerNorm Whisper encoder block: self-attention followed by a two-layer MLP.

    Both sub-blocks are wrapped in a residual connection, with dropout applied to each sub-block output.
    With `use_scan=True` the layer follows the scan calling convention: `hidden_states` arrives as a
    1-tuple carry and the return value is `(carry, all_hidden_states)`.
    """

    config: WhisperConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    params_dtype: jnp.dtype = jnp.float32  # forwarded to every sub-layer
    use_scan: bool = False

    def setup(self) -> None:
        """Instantiate attention, layer norms, dropout layers and the feed-forward projections."""
        self.embed_dim = self.config.d_model
        self.self_attn = FlaxWhisperAttention(
            config=self.config,
            embed_dim=self.embed_dim,
            num_heads=self.config.encoder_attention_heads,
            dropout=self.config.attention_dropout,
            dtype=self.dtype,
            params_dtype=self.params_dtype,
        )
        self.self_attn_layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype)
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)
        self.activation_fn = ACT2FN[self.config.activation_function]
        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
        self.fc1 = DenseGeneral(
            self.config.encoder_ffn_dim,
            dtype=self.dtype,
            params_dtype=self.params_dtype,
            kernel_axes=("embed", "mlp"),
        )
        self.fc2 = DenseGeneral(
            self.embed_dim,
            dtype=self.dtype,
            params_dtype=self.params_dtype,
            kernel_axes=("mlp", "embed"),
        )
        self.final_layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype)

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: jnp.ndarray,
        output_attentions: bool = True,
        deterministic: bool = True,
        all_hidden_states=None,  # only used when `use_scan=True` -> we have to fetch the hidden states from within the layer
    ) -> Tuple[jnp.ndarray]:
        """Apply the encoder block.

        Args:
            hidden_states: Layer input (a 1-tuple holding the input when `use_scan=True`).
            attention_mask: Mask forwarded to the self-attention module.
            output_attentions: Whether to append the attention weights to the output tuple.
            deterministic: Disables all dropout (including attention dropout) when `True`.
            all_hidden_states: Tuple accumulator used only with `use_scan=True`; `None` disables collection.

        Returns:
            `(hidden_states,)` or `(hidden_states, attn_weights)`; with `use_scan=True` that tuple is wrapped
            as `(outputs, all_hidden_states)`.
        """
        if self.use_scan:
            # under scan the carry is a 1-tuple wrapping the hidden states
            hidden_states = hidden_states[0]
        hidden_states = with_sharding_constraint(hidden_states, ("batch", "length", "embed"))

        # --- self-attention sub-block (pre-norm + residual) ---
        residual = hidden_states
        layernorm_output = self.self_attn_layer_norm(hidden_states)
        layernorm_output = with_sharding_constraint(layernorm_output, ("batch", "length", "embed"))
        # `deterministic` must be forwarded, otherwise the attention module falls back to its default
        # (`True`) and the configured `attention_dropout` is never applied during training.
        attn_output, attn_weights = self.self_attn(
            hidden_states=layernorm_output,
            attention_mask=attention_mask,
            deterministic=deterministic,
        )
        attn_output = self.dropout_layer(attn_output, deterministic=deterministic)
        attn_output = residual + attn_output
        attn_output = with_sharding_constraint(attn_output, ("batch", "length", "embed"))

        # --- feed-forward sub-block (pre-norm + residual) ---
        residual = attn_output
        post_layer_norm = self.final_layer_norm(attn_output)
        post_layer_norm = with_sharding_constraint(post_layer_norm, ("batch", "length", "embed"))
        fc1_output = self.activation_fn(self.fc1(post_layer_norm))
        fc1_output = self.activation_dropout_layer(fc1_output, deterministic=deterministic)
        fc1_output = with_sharding_constraint(fc1_output, ("batch", "length", "mlp"))
        hidden_states = self.fc2(fc1_output)
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        hidden_states = residual + hidden_states
        hidden_states = with_sharding_constraint(hidden_states, ("batch", "length", "embed"))

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        if self.use_scan:
            if all_hidden_states is not None:
                all_hidden_states = all_hidden_states + (hidden_states,)
            outputs = (
                outputs,
                all_hidden_states,
            )
        return outputs
class FlaxWhisperEncoderLayerCollection(nn.Module):
    """Runs the stack of `config.encoder_layers` encoder layers.

    Two execution modes are supported: a scanned stack (`use_scan=True`, one layer definition iterated
    `config.encoder_layers` times) and a plain Python loop with optional LayerDrop. Layers are wrapped in
    `remat` when `gradient_checkpointing` is enabled.
    """

    config: WhisperConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    params_dtype: jnp.dtype = jnp.float32
    use_scan: bool = False
    gradient_checkpointing: bool = False

    @nn.compact
    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Apply all encoder layers.

        Args:
            hidden_states: Input to the first layer.
            attention_mask: Mask forwarded to every layer.
            deterministic: Disables dropout and LayerDrop when `True`.
            output_attentions: Collect per-layer attention weights (not supported with `use_scan`).
            output_hidden_states: Collect the input and the output of every layer.
            return_dict: Return a `FlaxBaseModelOutput` instead of a tuple.

        Returns:
            `FlaxBaseModelOutput`, or a tuple of its non-`None` fields when `return_dict=False`.

        Raises:
            ValueError: If `use_scan` and `output_attentions` are both set.
        """
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        FlaxWhisperEncoderCheckpointLayer = (
            remat(
                FlaxWhisperEncoderLayer,
                # `output_attentions` and `deterministic` are Python bools and must stay static
                static_argnums=(2, 3),
                prevent_cse=not self.use_scan,
            )
            if self.gradient_checkpointing
            else FlaxWhisperEncoderLayer
        )

        if self.use_scan:
            if output_attentions:
                raise ValueError("Cannot use `scan` with `output_attentions` set to True")

            # nicest behaviour for scan is to let the compiler figure out the correct shapes for the hidden states
            # so we'll just pass an empty tuple as the carry initializer and hold on to the first hidden states for later
            input_hidden_states = hidden_states
            hidden_states = (hidden_states,)

            hidden_states, all_hidden_states = scan_with_axes(
                FlaxWhisperEncoderCheckpointLayer,
                variable_axes={"params": 0, "cache": 0},
                split_rngs={"params": True, "dropout": True},
                in_axes=(
                    nn.broadcast,
                    nn.broadcast,
                    nn.broadcast,
                    nn.broadcast,
                ),
                variable_carry="all_hidden_states",
                length=self.config.encoder_layers,
            )(
                self.config,
                dtype=self.dtype,
                params_dtype=self.params_dtype,
                use_scan=True,
                name="FlaxEncoderScanLayers",
            )(
                hidden_states,
                attention_mask,
                output_attentions,
                deterministic,
                all_hidden_states,  # tuple initializer (or None if not using output_hidden_states)
            )

            # remove the scan dimension
            hidden_states = hidden_states[0]

            if output_hidden_states:
                # if we're using scan we'll surely be training -> return hidden states as a tensor rather than tuple
                all_hidden_states = jnp.vstack([input_hidden_states[None, ...], all_hidden_states[0]])
        else:
            for layer_idx in range(self.config.encoder_layers):
                if output_hidden_states:
                    all_hidden_states = all_hidden_states + (hidden_states,)

                # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
                # NOTE(review): this is a Python-level RNG; under `jax.jit` it is presumably sampled once at
                # trace time rather than per step — confirm if per-step LayerDrop is expected.
                dropout_probability = random.uniform(0, 1)
                if not deterministic and (dropout_probability < self.config.encoder_layerdrop):
                    # Skip the layer: pass the input through unchanged. Returning `None` as the hidden states
                    # here would hand `None` to the next layer (and to the caller for the last layer).
                    layer_outputs = (hidden_states, None)
                else:
                    layer_outputs = FlaxWhisperEncoderCheckpointLayer(
                        self.config,
                        dtype=self.dtype,
                        params_dtype=self.params_dtype,
                        name=str(layer_idx),
                    )(
                        hidden_states,
                        attention_mask,
                        output_attentions,
                        deterministic,
                    )
                hidden_states = layer_outputs[0]
                if output_attentions:
                    # a skipped layer contributes `None` for its attention weights
                    all_attentions = all_attentions + (layer_outputs[1],)

            if output_hidden_states:
                all_hidden_states += (hidden_states,)

        outputs = (hidden_states, all_hidden_states, all_attentions)
        if not return_dict:
            return tuple(v for v in outputs if v is not None)
        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
class FlaxWhisperDecoderLayer(nn.Module):
    """Single pre-LayerNorm Whisper decoder block.

    Consists of causal self-attention, optional cross-attention over `encoder_hidden_states`, and a two-layer
    MLP, each wrapped in a residual connection with dropout. With `use_scan=True` the layer follows the scan
    calling convention: `hidden_states` arrives as a 1-tuple carry and the return value is
    `(carry, all_hidden_states)`.
    """

    config: WhisperConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    params_dtype: jnp.dtype = jnp.float32  # forwarded to every sub-layer
    use_scan: bool = False

    def setup(self) -> None:
        """Instantiate both attention modules, layer norms, dropout layers and the feed-forward projections."""
        self.embed_dim = self.config.d_model
        self.self_attn = FlaxWhisperAttention(
            config=self.config,
            embed_dim=self.embed_dim,
            num_heads=self.config.decoder_attention_heads,
            dropout=self.config.attention_dropout,
            causal=True,
            dtype=self.dtype,
            params_dtype=self.params_dtype,
        )
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)
        self.activation_fn = ACT2FN[self.config.activation_function]
        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
        self.self_attn_layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype)
        self.encoder_attn = FlaxWhisperAttention(
            config=self.config,
            embed_dim=self.embed_dim,
            num_heads=self.config.decoder_attention_heads,
            dropout=self.config.attention_dropout,
            dtype=self.dtype,
            params_dtype=self.params_dtype,
        )
        self.encoder_attn_layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype)
        self.fc1 = DenseGeneral(
            self.config.decoder_ffn_dim,
            dtype=self.dtype,
            params_dtype=self.params_dtype,
            kernel_axes=("embed", "mlp"),
        )
        self.fc2 = DenseGeneral(
            self.embed_dim,
            dtype=self.dtype,
            params_dtype=self.params_dtype,
            kernel_axes=("mlp", "embed"),
        )
        self.final_layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype)

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: jnp.ndarray,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        output_attentions: bool = True,
        deterministic: bool = True,
        all_hidden_states=None,  # only used when `use_scan=True` -> we have to fetch the hidden states from within the layer
    ) -> Tuple[jnp.ndarray]:
        """Apply the decoder block.

        Args:
            hidden_states: Layer input (a 1-tuple holding the input when `use_scan=True`).
            attention_mask: Mask forwarded to the causal self-attention.
            encoder_hidden_states: Encoder output used as keys/values for cross-attention. When `None`, the
                cross-attention sub-block is skipped entirely.
            encoder_attention_mask: Mask forwarded to the cross-attention.
            init_cache: Forwarded to the self-attention to set up its key/value cache.
            output_attentions: Whether to append self- and cross-attention weights to the output tuple.
            deterministic: Disables all dropout (including attention dropout) when `True`.
            all_hidden_states: Tuple accumulator used only with `use_scan=True`; `None` disables collection.

        Returns:
            `(hidden_states,)` or `(hidden_states, self_attn_weights, cross_attn_weights)`; with
            `use_scan=True` that tuple is wrapped as `(outputs, all_hidden_states)`.
        """
        if self.use_scan:
            # under scan the carry is a 1-tuple wrapping the hidden states
            hidden_states = hidden_states[0]
        hidden_states = with_sharding_constraint(hidden_states, ("batch", "length", "embed"))

        # Self Attention
        residual = hidden_states
        layer_norm_output = self.self_attn_layer_norm(hidden_states)
        layer_norm_output = with_sharding_constraint(layer_norm_output, ("batch", "length", "embed"))
        # `deterministic` is forwarded so that the configured `attention_dropout` is applied during training;
        # without it the attention module uses its default (`True`) and never drops attention weights.
        self_attn_output, self_attn_weights = self.self_attn(
            hidden_states=layer_norm_output,
            attention_mask=attention_mask,
            init_cache=init_cache,
            deterministic=deterministic,
        )
        self_attn_output = self.dropout_layer(self_attn_output, deterministic=deterministic)
        hidden_states = residual + self_attn_output
        hidden_states = with_sharding_constraint(hidden_states, ("batch", "length", "embed"))

        # Cross-Attention Block
        # `hidden_states` is carried through as the running activation so that the feed-forward block below
        # has a defined input even when there are no encoder states and this block is skipped.
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            residual = hidden_states
            encoder_layer_norm_output = self.encoder_attn_layer_norm(hidden_states)
            encoder_layer_norm_output = with_sharding_constraint(
                encoder_layer_norm_output, ("batch", "length", "embed")
            )
            cross_attn_output, cross_attn_weights = self.encoder_attn(
                hidden_states=encoder_layer_norm_output,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                deterministic=deterministic,
            )
            cross_attn_output = self.dropout_layer(cross_attn_output, deterministic=deterministic)
            hidden_states = residual + cross_attn_output
            hidden_states = with_sharding_constraint(hidden_states, ("batch", "length", "embed"))

        # Fully Connected
        residual = hidden_states
        post_layer_norm = self.final_layer_norm(hidden_states)
        post_layer_norm = with_sharding_constraint(post_layer_norm, ("batch", "length", "embed"))
        fc1_output = self.activation_fn(self.fc1(post_layer_norm))
        fc1_output = self.activation_dropout_layer(fc1_output, deterministic=deterministic)
        fc1_output = with_sharding_constraint(fc1_output, ("batch", "length", "mlp"))
        hidden_states = self.fc2(fc1_output)
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        hidden_states = residual + hidden_states
        hidden_states = with_sharding_constraint(hidden_states, ("batch", "length", "embed"))

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)
        if self.use_scan:
            if all_hidden_states is not None:
                all_hidden_states = all_hidden_states + (hidden_states,)
            outputs = (
                outputs,
                all_hidden_states,
            )
        return outputs
class FlaxWhisperDecoderLayerCollection(nn.Module):
    """Runs the `config.decoder_layers` decoder layers, either unrolled or as one scanned module.

    Attributes:
        config: model configuration; `decoder_layers` and `decoder_layerdrop` are read here.
        dtype: the dtype of the computation, forwarded to every layer.
        params_dtype: forwarded to every layer.
        use_scan: run the layers through `scan_with_axes` instead of a Python loop.
        gradient_checkpointing: wrap the layer class in `remat`.
    """

    config: WhisperConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    params_dtype: jnp.dtype = jnp.float32
    use_scan: bool = False
    gradient_checkpointing: bool = False

    @nn.compact
    def __call__(
        self,
        hidden_states,
        attention_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Apply all decoder layers to `hidden_states`.

        Returns:
            A `FlaxBaseModelOutputWithPastAndCrossAttentions` when `return_dict` is set, otherwise a tuple of the
            non-`None` entries of (last hidden state, all hidden states, self-attentions, cross-attentions).

        Raises:
            ValueError: if `use_scan` is combined with `output_attentions=True`.
        """
        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

        # Optionally rematerialise the layer's activations in the backward pass
        FlaxWhisperDecoderCheckpointLayer = (
            remat(
                FlaxWhisperDecoderLayer,
                static_argnums=(4, 5, 6),
                prevent_cse=not self.use_scan,
            )
            if self.gradient_checkpointing
            else FlaxWhisperDecoderLayer
        )

        if self.use_scan:
            if output_attentions:
                raise ValueError("Cannot use `scan` with `output_attentions` set to True")

            # keep the embeddings so they can be prepended to the per-layer hidden states below
            input_hidden_states = hidden_states
            # the hidden states are passed through the scan wrapped in a 1-tuple and unwrapped afterwards
            hidden_states = (hidden_states,)

            hidden_states, all_hidden_states = scan_with_axes(
                FlaxWhisperDecoderCheckpointLayer,
                variable_axes={"params": 0, "cache": 0},
                split_rngs={"params": True, "dropout": True},
                in_axes=(
                    nn.broadcast,
                    nn.broadcast,
                    nn.broadcast,
                    nn.broadcast,
                    nn.broadcast,
                    nn.broadcast,
                    nn.broadcast,
                ),
                variable_carry="all_hidden_states",
                length=self.config.decoder_layers,
            )(
                self.config,
                dtype=self.dtype,
                params_dtype=self.params_dtype,
                use_scan=True,
                name="FlaxDecoderScanLayers",
            )(
                hidden_states,
                attention_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                init_cache,
                output_attentions,
                deterministic,
                all_hidden_states,
            )
            hidden_states = hidden_states[0]

            if output_hidden_states:
                # if we're using scan we'll surely be training -> return hidden states as a tensor rather than tuple
                all_hidden_states = jnp.vstack([input_hidden_states[None, ...], all_hidden_states[0]])

        else:
            for layer_idx in range(self.config.decoder_layers):
                if output_hidden_states:
                    all_hidden_states += (hidden_states,)
                # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
                dropout_probability = random.uniform(0, 1)
                if not deterministic and (dropout_probability < self.config.decoder_layerdrop):
                    # A dropped layer acts as the identity: carry the current hidden states forward.
                    # Propagating `None` here would break every subsequent layer and the final output.
                    layer_outputs = (hidden_states, None, None)
                else:
                    layer_outputs = FlaxWhisperDecoderCheckpointLayer(
                        self.config,
                        dtype=self.dtype,
                        params_dtype=self.params_dtype,
                        name=str(layer_idx),
                    )(
                        hidden_states,
                        attention_mask,
                        encoder_hidden_states,
                        encoder_attention_mask,
                        init_cache,
                        output_attentions,
                        deterministic,
                    )

                hidden_states = layer_outputs[0]
                if output_attentions:
                    # dropped layers contribute `None` so the tuples stay aligned with the layer index
                    all_self_attns += (layer_outputs[1],)

                    if encoder_hidden_states is not None:
                        all_cross_attentions += (layer_outputs[2],)

            # add hidden states from the last decoder layer
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

        outputs = [
            hidden_states,
            all_hidden_states,
            all_self_attns,
            all_cross_attentions,
        ]

        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )
class FlaxWhisperEncoder(nn.Module):
    """Whisper encoder: two convolutions over the input features, position embeddings, the encoder layer
    stack and a final layer norm."""

    config: WhisperConfig
    dtype: jnp.dtype = jnp.float32
    params_dtype: jnp.dtype = jnp.float32
    use_scan: bool = False
    gradient_checkpointing: bool = False

    def setup(self) -> None:
        cfg = self.config
        # dtype settings shared by every parametrised sub-module
        dtypes = {"dtype": self.dtype, "params_dtype": self.params_dtype}

        self.conv1 = Conv(
            cfg.d_model,
            kernel_size=(3,),
            padding=1,
            kernel_axes=("channels", "num_mel", "embed"),
            **dtypes,
        )
        self.conv2 = Conv(
            cfg.d_model,
            kernel_size=(3,),
            strides=2,
            padding=1,
            kernel_axes=("channels", "embed", "num_mel"),
            **dtypes,
        )

        self.dropout_layer = nn.Dropout(rate=cfg.dropout)

        self.layers = FlaxWhisperEncoderLayerCollection(
            cfg,
            use_scan=self.use_scan,
            gradient_checkpointing=self.gradient_checkpointing,
            **dtypes,
        )
        self.embed_positions = Embed(cfg.max_source_positions, cfg.d_model, **dtypes)
        self.layer_norm = LayerNorm(epsilon=1e-05, **dtypes)

    def __call__(
        self,
        input_features: jnp.ndarray,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """Encode `input_features` of shape `(batch, num_mel_bins, 2 * max_source_positions)`.

        Raises:
            ValueError: if the trailing two dimensions of `input_features` differ from the shape above.
        """
        expected_shape = (
            self.config.num_mel_bins,
            self.config.max_source_positions * 2,
        )
        if input_features.shape[1:] != expected_shape:
            raise ValueError(
                "input_features.shape[1:], must be equal to (self.config.num_mel_bins,"
                " self.config.max_source_positions * 2) (got"
                f" {input_features.shape[1:]}, but should be"
                f" ({self.config.num_mel_bins},"
                f" {self.config.max_source_positions * 2}))"
            )

        # swap the last two axes before the convolutions
        features = input_features.transpose(0, 2, 1)

        hidden_states = jax.nn.gelu(self.conv1(features), approximate=False)
        hidden_states = with_sharding_constraint(hidden_states, ("batch", "embed", "num_mel"))
        hidden_states = jax.nn.gelu(self.conv2(hidden_states), approximate=False)
        hidden_states = with_sharding_constraint(hidden_states, ("batch", "length", "embed"))

        # sinusoidal positional embeddings should not be trained
        positions = jnp.arange(self.config.max_source_positions)
        embed_positions = jax.lax.stop_gradient(self.embed_positions(positions))

        hidden_states = self.dropout_layer(hidden_states + embed_positions, deterministic=deterministic)

        outputs = self.layers(
            hidden_states,
            attention_mask=None,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_states = self.layer_norm(outputs[0])

        if output_hidden_states:
            # swap the final entry of the collected hidden states for its layer-normed version
            hidden_states = outputs[1]
            if self.use_scan:
                hidden_states = jnp.vstack([hidden_states[:-1], last_hidden_states[None, ...]])
            else:
                hidden_states = hidden_states[:-1] + (last_hidden_states,)
        else:
            hidden_states = None

        if return_dict:
            return FlaxBaseModelOutput(
                last_hidden_state=last_hidden_states,
                hidden_states=hidden_states,
                attentions=outputs.attentions,
            )

        remaining = outputs[2:] if output_hidden_states else outputs[1:]
        outputs = (last_hidden_states, hidden_states) + remaining
        return tuple(v for v in outputs if v is not None)
class FlaxWhisperDecoder(nn.Module):
    """Whisper decoder: token and position embeddings, the decoder layer stack and a final layer norm."""

    config: WhisperConfig
    dtype: jnp.dtype = jnp.float32
    params_dtype: jnp.dtype = jnp.float32
    use_scan: bool = False
    gradient_checkpointing: bool = False

    def setup(self) -> None:
        cfg = self.config
        # dtype settings shared by every parametrised sub-module
        dtypes = {"dtype": self.dtype, "params_dtype": self.params_dtype}

        self.embed_tokens = Embed(cfg.vocab_size, cfg.d_model, **dtypes)
        self.embed_positions = Embed(cfg.max_target_positions, cfg.d_model, **dtypes)

        self.layers = FlaxWhisperDecoderLayerCollection(
            cfg,
            use_scan=self.use_scan,
            gradient_checkpointing=self.gradient_checkpointing,
            **dtypes,
        )

        self.dropout_layer = nn.Dropout(rate=cfg.dropout)
        self.layer_norm = LayerNorm(epsilon=1e-5, **dtypes)

    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: jnp.ndarray,
        position_ids: jnp.ndarray,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """Embed `input_ids`/`position_ids`, run the decoder layers and layer-norm the result."""
        token_plus_position = self.embed_tokens(input_ids) + self.embed_positions(position_ids)
        hidden_states = self.dropout_layer(token_plus_position, deterministic=deterministic)

        outputs = self.layers(
            hidden_states,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_states = self.layer_norm(outputs[0])

        if output_hidden_states:
            # swap the final entry of the collected hidden states for its layer-normed version
            hidden_states = outputs[1]
            if self.use_scan:
                hidden_states = jnp.vstack([hidden_states[:-1], last_hidden_states[None, ...]])
            else:
                hidden_states = hidden_states[:-1] + (last_hidden_states,)
        else:
            hidden_states = None

        if return_dict:
            return FlaxBaseModelOutputWithPastAndCrossAttentions(
                last_hidden_state=last_hidden_states,
                hidden_states=hidden_states,
                attentions=outputs.attentions,
                cross_attentions=outputs.cross_attentions,
            )

        remaining = outputs[2:] if output_hidden_states else outputs[1:]
        outputs = (last_hidden_states, hidden_states) + remaining
        return tuple(v for v in outputs if v is not None)
class FlaxWhisperModule(nn.Module):
    """Encoder-decoder Whisper module: runs the encoder, then the decoder on the encoder's last hidden state."""

    config: WhisperConfig
    dtype: jnp.dtype = jnp.float32
    params_dtype: jnp.dtype = jnp.float32
    use_scan: bool = False
    gradient_checkpointing: bool = False

    def setup(self) -> None:
        # encoder and decoder are configured identically
        shared_kwargs = {
            "dtype": self.dtype,
            "params_dtype": self.params_dtype,
            "use_scan": self.use_scan,
            "gradient_checkpointing": self.gradient_checkpointing,
        }
        self.encoder = FlaxWhisperEncoder(self.config, **shared_kwargs)
        self.decoder = FlaxWhisperDecoder(self.config, **shared_kwargs)

    def __call__(
        self,
        input_features: jnp.ndarray,
        decoder_input_ids: jnp.ndarray,
        decoder_attention_mask: jnp.ndarray,
        decoder_position_ids: jnp.ndarray,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        freeze_encoder: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """Run the full encoder-decoder forward pass.

        With `freeze_encoder` the encoder output is wrapped in `stop_gradient` before it is handed to the decoder.
        Returns a `FlaxSeq2SeqModelOutput` when `return_dict` is set, otherwise the decoder outputs followed by
        the encoder outputs.
        """
        # options forwarded unchanged to both halves of the model
        output_flags = {
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
            "deterministic": deterministic,
        }

        encoder_outputs = self.encoder(input_features, **output_flags)

        encoder_hidden_states = encoder_outputs[0]
        if freeze_encoder:
            encoder_hidden_states = jax.lax.stop_gradient(encoder_hidden_states)

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            position_ids=decoder_position_ids,
            encoder_hidden_states=encoder_hidden_states,
            **output_flags,
        )

        if return_dict:
            return FlaxSeq2SeqModelOutput(
                last_hidden_state=decoder_outputs.last_hidden_state,
                decoder_hidden_states=decoder_outputs.hidden_states,
                decoder_attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
                encoder_last_hidden_state=encoder_outputs.last_hidden_state,
                encoder_hidden_states=encoder_outputs.hidden_states,
                encoder_attentions=encoder_outputs.attentions,
            )

        return decoder_outputs + encoder_outputs

    def _get_encoder_module(self):
        """Return the encoder sub-module."""
        return self.encoder

    def _get_decoder_module(self):
        """Return the decoder sub-module."""
        return self.decoder
class FlaxWhisperPreTrainedModel(FlaxPreTrainedModel):
    """Common base of the Flax Whisper model wrappers defined in this file.

    Holds the module construction, weight initialisation, scan / gradient-checkpointing toggles and the
    `encode` / `decode` / `__call__` entry points; subclasses select the module via `module_class`.
    """

    config_class = WhisperConfig
    base_model_prefix: str = "model"
    main_input_name = "input_features"
    # Flax module class instantiated in `__init__`; subclasses override this.
    module_class: nn.Module = None
def __init__(
    self,
    config: WhisperConfig,
    input_shape: Tuple[int, int, int] = None,
    seed: int = 0,
    dtype: jnp.dtype = jnp.float32,
    params_dtype: jnp.dtype = jnp.float32,
    _do_init: bool = True,
    # Can only use_scan=True in init if loading scanned weights -> need to handle use_scan=True and unrolled weights
    use_scan: bool = False,
    gradient_checkpointing: bool = False,
    **kwargs,
):
    """Build the Flax module from `module_class` and hand it to the `FlaxPreTrainedModel` constructor.

    Extra `kwargs` are forwarded to the module constructor.
    """
    # Without an explicit shape, use one example with the dimensions derived from the config.
    if input_shape is None:
        input_shape = (1, config.num_mel_bins, 2 * config.max_source_positions)

    # Remembered so the module can be rebuilt later by the enable_* / disable_* methods.
    self.use_scan = use_scan
    self.gradient_checkpointing = gradient_checkpointing

    whisper_module = self.module_class(
        config=config,
        dtype=dtype,
        params_dtype=params_dtype,
        use_scan=use_scan,
        gradient_checkpointing=gradient_checkpointing,
        **kwargs,
    )

    super().__init__(
        config,
        whisper_module,
        input_shape=input_shape,
        seed=seed,
        dtype=dtype,
        _do_init=_do_init,
    )
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
    """Initialise the module parameters on dummy inputs.

    When `params` is given, only the entries listed in `self._missing_keys` are taken from the freshly
    initialised parameters; the merged tree is returned frozen and `self._missing_keys` is reset.
    Without `params`, the freshly initialised parameters are returned as they are.
    """
    # init input tensors
    dummy_features = jnp.zeros(input_shape, dtype="f4").at[(..., -1)].set(self.config.eos_token_id)
    dummy_ids = jnp.zeros((input_shape[0], 1), dtype="i4")
    dummy_mask = jnp.ones_like(dummy_ids)
    dummy_positions = jnp.broadcast_to(jnp.arange(dummy_ids.shape[1])[None, :], dummy_ids.shape)

    params_rng, dropout_rng = jax.random.split(rng)

    random_params = self.module.init(
        {"params": params_rng, "dropout": dropout_rng},
        input_features=dummy_features,
        decoder_input_ids=dummy_ids,
        decoder_attention_mask=dummy_mask,
        decoder_position_ids=dummy_positions,
    )["params"]

    if params is None:
        return random_params

    # Fill only the keys that are missing from the supplied parameters.
    flat_random = flatten_dict(unfreeze(random_params))
    flat_params = flatten_dict(unfreeze(params))
    for key in self._missing_keys:
        flat_params[key] = flat_random[key]
    self._missing_keys = set()
    return freeze(unflatten_dict(flat_params))
def enable_gradient_checkpointing(self):
    """Rebuild the wrapped module with gradient checkpointing switched on.

    The module is re-created from `module_class` with the current `use_scan` setting. The parameter dtype
    of the module being replaced is carried over; extra keyword arguments that were passed to `__init__`
    are not stored and therefore not re-applied.
    """
    self.gradient_checkpointing = True
    self._module = self.module_class(
        config=self.config,
        dtype=self.dtype,
        # keep the parameter dtype chosen at construction instead of falling back to the class default
        params_dtype=self.module.params_dtype,
        use_scan=self.use_scan,
        gradient_checkpointing=self.gradient_checkpointing,
    )
def enable_scan(self):
    """Switch the model to the scanned layer layout.

    Rebuilds the module with `use_scan=True` (carrying over the parameter dtype of the module being
    replaced), recomputes the expected parameter shapes and, if the model already holds parameters,
    converts them from the unrolled to the scanned format.
    """
    self.use_scan = True
    self._module = self.module_class(
        config=self.config,
        dtype=self.dtype,
        # keep the parameter dtype chosen at construction instead of falling back to the class default
        params_dtype=self.module.params_dtype,
        use_scan=self.use_scan,
        gradient_checkpointing=self.gradient_checkpointing,
    )
    # evaluate only the shapes of the parameters of the rebuilt module
    init_fn = partial(self.init_weights, input_shape=self.input_shape)
    params_shape_tree = jax.eval_shape(init_fn, self.key)

    # get the shape of the parameters
    self._params_shape_tree = params_shape_tree

    # save required_params as set
    self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())

    # convert the parameters already held by the model to the scanned layout
    if self._is_initialized:
        self.params = self.convert_unroll_to_scan(self.params)
def disable_scan(self):
    """Switch the model back to the unrolled layer layout.

    Rebuilds the module with `use_scan=False` (carrying over the parameter dtype of the module being
    replaced), recomputes the expected parameter shapes and, if the model already holds parameters,
    converts them from the scanned to the unrolled format.
    """
    self.use_scan = False
    self._module = self.module_class(
        config=self.config,
        dtype=self.dtype,
        # keep the parameter dtype chosen at construction instead of falling back to the class default
        params_dtype=self.module.params_dtype,
        use_scan=self.use_scan,
        gradient_checkpointing=self.gradient_checkpointing,
    )
    # evaluate only the shapes of the parameters of the rebuilt module
    init_fn = partial(self.init_weights, input_shape=self.input_shape)
    params_shape_tree = jax.eval_shape(init_fn, self.key)

    # get the shape of the parameters
    self._params_shape_tree = params_shape_tree

    # save required_params as set
    self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())

    # convert the parameters already held by the model to the unrolled layout
    if self._is_initialized:
        self.params = self.convert_scan_to_unroll(self.params)
def convert_unroll_to_scan(self, params: Union[Dict, FrozenDict]):
    r"""
    Convert a `PyTree` of unrolled model parameters to a scanned block of model parameters. This method can be used
    to explicitly convert the model parameters to scanned format. This returns a new `params` tree and does not
    convert the `params` in place.

    The unrolled format holds one sub-tree per layer, keyed by the layer index:

    (..., 'layers', '0', ...) (..., 'layers', '1', ...) ... (..., 'layers', 'N-1', ...)

    This method takes each parameter of layers (0, ..., N-1) and stacks them along a new leading axis, giving a
    *single* block of weights for all N layers compatible with the scanned model. Keys that contain `decoder`
    are stacked over `config.decoder_layers` layers, all other layer keys over `config.encoder_layers`:

    (..., 'layers', 'FlaxEncoderScanLayers', ...) / (..., 'layers', 'FlaxDecoderScanLayers', ...)

    When enabling scan with _do_init=True (default), this method will be called automatically under the hood. With
    _do_init=False, it will have to be called explicitly (see example below).

    Arguments:
        params (`Union[Dict, FrozenDict]`):
            A `PyTree` of model parameters.

    Examples:

    ```python
    >>> from distil_whisper import FlaxWhisperForConditionalGeneration

    >>> # Download model and configuration from huggingface.co
    >>> model, params = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", _do_init=False)
    >>> # By default, the model params will be in unrolled format. To illustrate the use of this method,
    >>> # we'll first convert to scan format and then back to unrolled
    >>> model.enable_scan()
    >>> params = model.convert_unroll_to_scan(params)
    >>> # now convert back to unrolled
    >>> model.disable_scan()
    >>> params = model.convert_scan_to_unroll(params)
    ```"""
    if isinstance(params, FrozenDict):
        params = unfreeze(params)

    params = flatten_dict(params, sep="/")
    # snapshot of the keys: `params` is mutated while iterating
    keys = list(params.keys())

    for k in keys:
        # Identify all "unrolled" layers formed as part of a layer collection.
        # The key of layer 0 serves as the template for the keys of the other layers.
        if "layers/0" in k:
            if "decoder" in k:
                block_prefix = "Decoder"
                num_hidden_layers = self.config.decoder_layers
            else:
                block_prefix = "Encoder"
                num_hidden_layers = self.config.encoder_layers

            # Squash the keys for the N unrolled layers into one single key:
            # (layers/0, ..., layers/N-1) -> layers/Flax{Encoder,Decoder}ScanLayers
            # Only the layer index is substituted, so a "0" elsewhere in the key is left untouched.
            scan_key = k.replace("layers/0", f"layers/Flax{block_prefix}ScanLayers")

            # Collect the params of layers (0, ..., N-1), removing the unrolled
            # entries on the fly -> no memory overhead for conversion!
            stacked_params = [
                params.pop(k.replace("layers/0", f"layers/{i}")) for i in range(num_hidden_layers)
            ]

            params[scan_key] = jnp.stack(stacked_params)

    # Finally, unflatten the dict to restore the nested pytree structure
    params = unflatten_dict(params, sep="/")
    return params
def convert_scan_to_unroll(self, params: Union[Dict, FrozenDict]):
    r"""
    Convert a `PyTree` of scanned model parameters to an unrolled stack of model parameters. This method can be
    used to explicitly convert the model parameters to unrolled format. This returns a new `params` tree and does
    not convert the `params` in place.

    In the scanned format every layer parameter is a single matrix stacked over all N layers, stored under the
    key component `FlaxEncoderScanLayers` or `FlaxDecoderScanLayers`. This method slices each stacked matrix
    into N standalone entries and replaces that key component with the layer index:

    (..., 'layers', 'FlaxEncoderScanLayers', ...) -> (..., 'layers', '0', ...) ... (..., 'layers', 'N-1', ...)

    When enabling scan with _do_init=True (default), this method will be called automatically under the hood. With
    _do_init=False, it will have to be called explicitly (see example below).

    Arguments:
        params (`Union[Dict, FrozenDict]`):
            A `PyTree` of model parameters.

    Examples:

    ```python
    >>> from distil_whisper import FlaxWhisperForConditionalGeneration

    >>> # Download model and configuration from huggingface.co
    >>> model, params = FlaxWhisperModel.from_pretrained("openai/whisper-tiny.en", _do_init=False)
    >>> # By default, the model params will be in unrolled format. To illustrate the use of this method,
    >>> # we'll first convert to scan format and then back to unrolled
    >>> model.enable_scan()
    >>> params = model.convert_unroll_to_scan(params)
    >>> # now convert back to unrolled
    >>> model.disable_scan()
    >>> params = model.convert_scan_to_unroll(params)
    ```"""
    if isinstance(params, FrozenDict):
        params = unfreeze(params)

    params = flatten_dict(params, sep="/")

    # iterate over a snapshot of the keys: `params` is mutated inside the loop
    for flat_key in list(params.keys()):
        # Work out which scanned block (if any) this key belongs to and how many layers it holds
        if "FlaxEncoderScanLayers" in flat_key:
            marker = "FlaxEncoderScanLayers"
            num_layers = self.config.encoder_layers
        elif "FlaxDecoderScanLayers" in flat_key:
            marker = "FlaxDecoderScanLayers"
            num_layers = self.config.decoder_layers
        else:
            continue

        # Remove the stacked matrix from the tree, then peel one layer off its front per iteration
        # so the scanned params are released on the fly -> no memory overhead for conversion!
        stacked = params.pop(flat_key)
        for layer_idx in range(num_layers):
            layer_key = flat_key.replace(marker, str(layer_idx))
            first_layer, remaining = stacked[0], stacked[1:]
            params[layer_key] = first_layer
            stacked = remaining

    params = unflatten_dict(params, sep="/")
    return params
# Copied from transformers.models.whisper.modeling_flax_whisper.FlaxWhisperPreTrainedModel.init_cache
def init_cache(self, batch_size, max_length, encoder_outputs):
    r"""
    Initialise the decoder cache by running the decoder's `init` on dummy inputs with `init_cache=True`.

    Args:
        batch_size (`int`):
            batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
        max_length (`int`):
            maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
            cache.
        encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
            `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
            `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*)
            is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
            cross-attention of the decoder.

    Returns:
        The unfrozen `cache` collection of the initialised variables.
    """
    # init input variables to retrieve cache
    decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
    decoder_attention_mask = jnp.ones_like(decoder_input_ids)
    # positions 0..max_length-1, broadcast over the batch
    decoder_position_ids = jnp.broadcast_to(
        jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]),
        decoder_input_ids.shape,
    )

    # Passed as `method=` below so that only the decoder sub-module is invoked.
    def _decoder_forward(
        module,
        decoder_input_ids,
        decoder_attention_mask,
        decoder_position_ids,
        **kwargs,
    ):
        decoder_module = module._get_decoder_module()
        return decoder_module(
            decoder_input_ids,
            decoder_attention_mask,
            decoder_position_ids,
            **kwargs,
        )

    # Only the "cache" collection of the result is used; the parameters created here are discarded.
    init_variables = self.module.init(
        jax.random.PRNGKey(0),
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
        decoder_position_ids=decoder_position_ids,
        encoder_hidden_states=encoder_outputs[0],
        init_cache=True,
        method=_decoder_forward,  # we only need to call the decoder to init the cache
    )
    return unfreeze(init_variables["cache"])
@add_start_docstrings(WHISPER_ENCODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=WhisperConfig)
def encode(
    self,
    input_features: jnp.ndarray,
    attention_mask: Optional[jnp.ndarray] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    train: bool = False,
    params: dict = None,
    dropout_rng: PRNGKey = None,
    **kwargs,
):
    r"""
    Returns:

    Example:

    ```python
    >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
    >>> from datasets import load_dataset

    >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
    >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
    >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
    >>> input_features = inputs.input_features
    >>> encoder_outputs = model.encode(input_features=input_features)
    ```"""
    # Every output switch left unset falls back to the model configuration.
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.return_dict

    # Handle any PRNG if needed
    rngs = {} if dropout_rng is None else {"dropout": dropout_rng}

    # Passed as `method=` so that only the encoder sub-module is applied.
    def _encoder_forward(module, input_features, **kwargs):
        return module._get_encoder_module()(input_features, **kwargs)

    variables = {"params": params or self.params}
    return self.module.apply(
        variables,
        input_features=jnp.array(input_features, dtype="f4"),
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        deterministic=not train,
        rngs=rngs,
        method=_encoder_forward,
    )
@add_start_docstrings(WHISPER_DECODE_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=FlaxBaseModelOutputWithPastAndCrossAttentions,
    config_class=WhisperConfig,
)
def decode(
    self,
    decoder_input_ids,
    encoder_outputs,
    encoder_attention_mask: Optional[jnp.ndarray] = None,
    decoder_attention_mask: Optional[jnp.ndarray] = None,
    decoder_position_ids: Optional[jnp.ndarray] = None,
    past_key_values: dict = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    train: bool = False,
    params: dict = None,
    dropout_rng: PRNGKey = None,
):
    r"""
    Returns:

    Example:

    ```python
    >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
    >>> from datasets import load_dataset

    >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
    >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
    >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
    >>> input_features = inputs.input_features
    >>> encoder_outputs = model.encode(input_features=input_features)
    >>> decoder_start_token_id = model.config.decoder_start_token_id

    >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

    >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
    >>> last_decoder_hidden_states = outputs.last_hidden_state
    ```"""
    # Every output switch left unset falls back to the model configuration.
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.return_dict

    encoder_hidden_states = encoder_outputs[0]

    batch_size, sequence_length = decoder_input_ids.shape
    if decoder_position_ids is None:
        # With a cache the positions cannot be derived from the inputs alone.
        if past_key_values is not None:
            raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")

        if decoder_attention_mask is None:
            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )
        else:
            decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1

    if decoder_attention_mask is None:
        decoder_attention_mask = jnp.ones((batch_size, sequence_length))

    # Handle any PRNG if needed
    rngs = {} if dropout_rng is None else {"dropout": dropout_rng}

    variables = {"params": params or self.params}

    # if past_key_values are passed then cache is already initialized a private flag init_cache has to be
    # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
    # it can be changed by FlaxWhisperAttention module
    mutable = False
    if past_key_values:
        variables["cache"] = past_key_values
        mutable = ["cache"]

    # Passed as `method=` so that only the decoder sub-module is applied.
    def _decoder_forward(
        module,
        decoder_input_ids,
        decoder_attention_mask,
        decoder_position_ids,
        **kwargs,
    ):
        return module._get_decoder_module()(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            position_ids=decoder_position_ids,
            **kwargs,
        )

    outputs = self.module.apply(
        variables,
        decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
        decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
        decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
        encoder_hidden_states=encoder_hidden_states,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        deterministic=not train,
        rngs=rngs,
        mutable=mutable,
        method=_decoder_forward,
    )

    if past_key_values is None:
        return outputs

    # add updated cache to model output
    outputs, mutated_variables = outputs
    updated_cache = unfreeze(mutated_variables["cache"])
    if return_dict:
        outputs["past_key_values"] = updated_cache
        return outputs
    return outputs[:1] + (updated_cache,) + outputs[1:]
@add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
def __call__(
    self,
    input_features: jnp.ndarray,
    decoder_input_ids: jnp.ndarray,
    attention_mask: Optional[jnp.ndarray] = None,
    decoder_attention_mask: Optional[jnp.ndarray] = None,
    position_ids: Optional[jnp.ndarray] = None,
    decoder_position_ids: Optional[jnp.ndarray] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    freeze_encoder: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    train: bool = False,
    params: dict = None,
    dropout_rng: PRNGKey = None,
):
    # Every output switch left unset falls back to the model configuration.
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.return_dict

    # prepare decoder inputs
    if decoder_position_ids is None:
        if decoder_attention_mask is None:
            # no mask: plain 0..sequence_length-1 positions for every row
            batch_size, sequence_length = decoder_input_ids.shape
            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )
        else:
            # positions derived from the running count of attended tokens
            decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1

    if decoder_attention_mask is None:
        decoder_attention_mask = jnp.ones_like(decoder_input_ids)

    # Handle any PRNG if needed
    rngs = {}
    if dropout_rng is not None:
        rngs["dropout"] = dropout_rng

    variables = {"params": params or self.params}
    return self.module.apply(
        variables,
        input_features=jnp.array(input_features, dtype="f4"),
        decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
        decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
        decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        freeze_encoder=freeze_encoder,
        return_dict=return_dict,
        deterministic=not train,
        rngs=rngs,
    )
# Model wrapper around `FlaxWhisperModule`; all behaviour is inherited from `FlaxWhisperPreTrainedModel`.
@add_start_docstrings(
    ("The bare Whisper Model transformer outputting raw hidden-states without any specific head on top."),
    WHISPER_START_DOCSTRING,
)
class FlaxWhisperModel(FlaxWhisperPreTrainedModel):
    config: WhisperConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    params_dtype: jnp.dtype = jnp.float32
    # Flax module class instantiated by `FlaxWhisperPreTrainedModel.__init__`
    module_class = FlaxWhisperModule


# NOTE(review): presumably appends a usage example to the `__call__` docstring of the model; confirm against
# `transformers.modeling_flax_utils.append_call_sample_docstring`.
append_call_sample_docstring(FlaxWhisperModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC)
class FlaxWhisperForConditionalGenerationModule(nn.Module):
    """`FlaxWhisperModule` followed by a bias-free projection of the decoder output onto the vocabulary."""

    config: WhisperConfig
    dtype: jnp.dtype = jnp.float32
    params_dtype: jnp.dtype = jnp.float32
    use_scan: bool = False
    gradient_checkpointing: bool = False

    def setup(self) -> None:
        # dtype settings shared by the backbone and the head
        dtypes = {"dtype": self.dtype, "params_dtype": self.params_dtype}

        self.model = FlaxWhisperModule(
            config=self.config,
            use_scan=self.use_scan,
            gradient_checkpointing=self.gradient_checkpointing,
            **dtypes,
        )
        self.lm_head = DenseGeneral(
            self.config.vocab_size,
            use_bias=False,
            kernel_axes=("embed", "vocab"),
            **dtypes,
        )

    def _get_encoder_module(self):
        """Return the encoder of the wrapped model."""
        return self.model.encoder

    def _get_decoder_module(self):
        """Return the decoder of the wrapped model."""
        return self.model.decoder

    def __call__(
        self,
        input_features,
        decoder_input_ids,
        decoder_attention_mask: jnp.ndarray = None,
        decoder_position_ids: jnp.ndarray = None,
        position_ids: jnp.ndarray = None,
        attention_mask: jnp.ndarray = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        freeze_encoder: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """Run the encoder-decoder model and project its first output to vocabulary logits.

        `position_ids` and `attention_mask` are accepted but not used by this module.
        """
        outputs = self.model(
            input_features=input_features,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            freeze_encoder=freeze_encoder,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        decoder_states = outputs[0]

        if not self.config.tie_word_embeddings:
            lm_logits = self.lm_head(decoder_states)
        else:
            # tied weights: apply the head with the transposed token embedding as its kernel
            token_embedding = self.model.decoder.embed_tokens.variables["params"]["embedding"]
            lm_logits = self.lm_head.apply({"params": {"kernel": token_embedding.T}}, decoder_states)

        if return_dict:
            return FlaxSeq2SeqLMOutput(
                logits=lm_logits,
                decoder_hidden_states=outputs.decoder_hidden_states,
                decoder_attentions=outputs.decoder_attentions,
                cross_attentions=outputs.cross_attentions,
                encoder_last_hidden_state=outputs.encoder_last_hidden_state,
                encoder_hidden_states=outputs.encoder_hidden_states,
                encoder_attentions=outputs.encoder_attentions,
            )

        return (lm_logits,) + outputs[1:]
@add_start_docstrings("The Whisper Model with a language modeling head.", WHISPER_START_DOCSTRING)
class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel):
    # Flax module wrapped by this model class.
    module_class = FlaxWhisperForConditionalGenerationModule

    @add_start_docstrings(WHISPER_DECODE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=WhisperConfig)
    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Returns:

        Example:

        ```python
        >>> import jax.numpy as jnp
        >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
        >>> from datasets import load_dataset

        >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
        >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
        >>> input_features = inputs.input_features
        >>> encoder_outputs = model.encode(input_features=input_features)
        >>> decoder_start_token_id = model.config.decoder_start_token_id

        >>> decoder_input_ids = jnp.ones((input_features.shape[0], 1), dtype="i4") * decoder_start_token_id

        >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
        >>> logits = outputs.logits
        ```"""
        # The argument documentation is prepended by the decorators above, so the docstring only carries
        # the `Returns:` placeholder and the example. `encoder_attention_mask` is accepted but is not
        # forwarded to the decoder by this method.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        encoder_hidden_states = encoder_outputs[0]

        batch_size, sequence_length = decoder_input_ids.shape
        if decoder_position_ids is None:
            if past_key_values is not None:
                raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")

            if decoder_attention_mask is not None:
                # Positions count attended tokens only; masked slots end up at -1.
                decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1
            else:
                decoder_position_ids = jnp.broadcast_to(
                    jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
                )

        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones((batch_size, sequence_length), dtype="i4")

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # If past_key_values are passed then the cache is already initialized and must be marked as mutable so
        # that it can be changed by the attention modules. One flag drives the `mutable` setting, the unpacking
        # of the `apply` result and the returned cache, so an empty cache is handled like "no cache" throughout
        # instead of failing during unpacking.
        use_cache = bool(past_key_values)
        if use_cache:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        def _decoder_forward(
            module,
            decoder_input_ids,
            decoder_attention_mask,
            decoder_position_ids,
            **kwargs,
        ):
            """Run only the decoder stack and the LM head of `module`."""
            decoder_module = module._get_decoder_module()
            outputs = decoder_module(
                input_ids=decoder_input_ids,
                attention_mask=decoder_attention_mask,
                position_ids=decoder_position_ids,
                **kwargs,
            )
            hidden_states = outputs[0]

            if self.config.tie_word_embeddings:
                # Tied weights: reuse the transposed token-embedding table as the head kernel.
                shared_embedding = module.model.decoder.embed_tokens.variables["params"]["embedding"]
                lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
            else:
                lm_logits = module.lm_head(hidden_states)

            return lm_logits, outputs

        outputs = self.module.apply(
            inputs,
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
            method=_decoder_forward,
        )

        # With a mutable collection, `apply` returns `(result, mutated_variables)`.
        if use_cache:
            (lm_logits, decoder_outputs), past = outputs
        else:
            lm_logits, decoder_outputs = outputs

        if return_dict:
            outputs = FlaxCausalLMOutputWithCrossAttentions(
                logits=lm_logits,
                hidden_states=decoder_outputs.hidden_states,
                attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
            )
        else:
            outputs = (lm_logits,) + decoder_outputs[1:]

        # add updated cache to model output
        if use_cache:
            updated_cache = unfreeze(past["cache"])
            if return_dict:
                outputs["past_key_values"] = updated_cache
            else:
                outputs = outputs[:1] + (updated_cache,) + outputs[1:]

        return outputs

    def generate(
        self,
        input_features,
        generation_config=None,
        logits_processor=None,
        return_timestamps=None,
        task=None,
        language=None,
        is_multilingual=None,
        **kwargs,
    ):
        """Generate with Whisper-specific language / task / timestamp handling.

        Args:
            input_features: Forwarded as the first positional argument to the parent `generate`.
            generation_config: Generation config to use; defaults to `self.generation_config`.
                The overrides below are written onto this object in place.
            logits_processor: Forwarded to the parent `generate`. Replaced by a single
                `FlaxWhisperTimeStampLogitsProcessor` when timestamps are enabled.
            return_timestamps: If not `None`, stored on `generation_config.return_timestamps`.
            task: If not `None`, stored on `generation_config.task`.
            language: If not `None`, stored on `generation_config.language`.
            is_multilingual: If not `None`, stored on `generation_config.is_multilingual`.
            **kwargs: Forwarded to the parent `generate`. If `decoder_input_ids` is present, its
                length along the last axis is handed to the timestamp logits processor.

        Returns:
            Whatever the parent class' `generate` returns.
        """
        if generation_config is None:
            generation_config = self.generation_config

        if return_timestamps is not None:
            generation_config.return_timestamps = return_timestamps

        if task is not None:
            generation_config.task = task

        if is_multilingual is not None:
            generation_config.is_multilingual = is_multilingual

        if language is not None:
            generation_config.language = language

        if "decoder_input_ids" in kwargs:
            # Use the last axis: `len()` of a (batch, length) array would give the batch size.
            decoder_input_length = jnp.shape(kwargs["decoder_input_ids"])[-1]
        else:
            decoder_input_length = 1

        # (position, token_id) pairs forced during decoding.
        forced_decoder_ids = []

        if getattr(generation_config, "is_multilingual", False):
            # A `language` / `task` attribute that exists but is None is treated like a missing one.
            language_name = getattr(generation_config, "language", None)
            if language_name is not None:
                forced_decoder_ids.append((1, generation_config.lang_to_id[language_name]))
            else:
                forced_decoder_ids.append((1, None))

            task_name = getattr(generation_config, "task", None)
            if task_name is not None:
                forced_decoder_ids.append((2, generation_config.task_to_id[task_name]))
            else:
                forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))

        # An explicit `return_timestamps` argument has already been copied onto the config above.
        if getattr(generation_config, "return_timestamps", False):
            logits_processor = [
                FlaxWhisperTimeStampLogitsProcessor(generation_config, self.config, decoder_input_length)
            ]
        elif forced_decoder_ids and forced_decoder_ids[-1][1] != generation_config.no_timestamps_token_id:
            # Compare the forced *token id* (not its position) before appending the no-timestamps token.
            next_position = forced_decoder_ids[-1][0] + 1
            forced_decoder_ids.append((next_position, generation_config.no_timestamps_token_id))

        if len(forced_decoder_ids) > 0:
            generation_config.forced_decoder_ids = forced_decoder_ids

        return super().generate(
            input_features,
            generation_config,
            logits_processor=logits_processor,
            **kwargs,
        )

    def pipeline_generate(
        self,
        input_features,
        forced_decoder_ids,
        return_timestamps=False,
        generation_config=None,
        **kwargs,
    ):
        """Generate with caller-supplied forced decoder ids applied through a logits processor.

        Args:
            input_features: Forwarded as the first positional argument to the parent `generate`.
            forced_decoder_ids: Handed to `FlaxStaticForceTokensLogitsProcessor`.
            return_timestamps: When truthy (and the config has a `return_timestamps` attribute), a
                `FlaxWhisperTimeStampLogitsProcessor` is appended as well.
            generation_config: Generation config to use; defaults to `self.generation_config`.
                Its `forced_decoder_ids` attribute is reset to `None` in place.
            **kwargs: Forwarded to the parent `generate`.

        Returns:
            Whatever the parent class' `generate` returns.
        """
        if generation_config is None:
            generation_config = self.generation_config

        # override the generation config forced decoder ids in preference of the ones we have set
        generation_config.forced_decoder_ids = None

        logits_processor = FlaxLogitsProcessorList()
        logits_processor.append(FlaxStaticForceTokensLogitsProcessor(forced_decoder_ids))

        if hasattr(generation_config, "return_timestamps") and return_timestamps:
            logits_processor.append(FlaxWhisperTimeStampLogitsProcessor(generation_config, self.config, 1))

        return super().generate(
            input_features,
            generation_config,
            logits_processor=logits_processor,
            **kwargs,
        )

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        max_length,
        attention_mask: Optional[jax.Array] = None,
        decoder_attention_mask: Optional[jax.Array] = None,
        encoder_outputs=None,
        **kwargs,
    ):
        """Build the keyword arguments for the first decoding step.

        Initialises the cache via `self.init_cache`, creates a decoder attention mask of shape
        `(batch_size, max_length)` and derives the decoder position ids.

        Returns:
            A dict with the keys `past_key_values`, `encoder_outputs`, `encoder_attention_mask`,
            `decoder_attention_mask` and `decoder_position_ids`.
        """
        # initializing the cache
        batch_size, seq_length = decoder_input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and
        # x < cache_length. But since the decoder uses a causal mask, those positions are masked anyways.
        # Thus we can create a single static attention_mask here, which is more efficient for compilation
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if decoder_attention_mask is not None:
            position_ids = decoder_attention_mask.cumsum(-1) - 1
            # Copy the caller's mask into the leading columns of the all-ones mask.
            extended_attention_mask = lax.dynamic_update_slice(
                extended_attention_mask, decoder_attention_mask, (0, 0)
            )
        else:
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

        return {
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "encoder_attention_mask": attention_mask,
            "decoder_attention_mask": extended_attention_mask,
            "decoder_position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        """Carry the cache forward and advance the decoder position ids by one step.

        Only the last position of each row is kept and incremented.
        """
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
        return model_kwargs
# Return-section placeholder plus a transcription example for the conditional-generation model.
# The example passes the features positionally because `generate` names its first parameter
# `input_features`, and decodes the `sequences` field of the generation output.
FLAX_WHISPER_CONDITIONAL_GENERATION_DOCSTRING = r"""
    Returns:

    Transcription example:

    ```python
    >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
    >>> from datasets import load_dataset

    >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
    >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
    >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
    >>> input_features = inputs.input_features
    >>> generated_ids = model.generate(input_features).sequences
    >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    >>> transcription
    ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
    ```
"""

# Register the inputs documentation followed by the example above on the model class.
overwrite_call_docstring(
    FlaxWhisperForConditionalGeneration,
    WHISPER_INPUTS_DOCSTRING + FLAX_WHISPER_CONDITIONAL_GENERATION_DOCSTRING,
)
# Fill the `Returns:` placeholder with the documentation of `FlaxSeq2SeqLMOutput`.
append_replace_return_docstrings(
    FlaxWhisperForConditionalGeneration,
    output_type=FlaxSeq2SeqLMOutput,
    config_class=_CONFIG_FOR_DOC,
)