""" Classes to support Flax Speech-Encoder-Decoder architectures""" |
|
|
|
import os |
|
from functools import partial |
|
from typing import Dict, Optional, Tuple, Union
|
|
|
import flax |
|
import flax.linen as nn |
|
import jax |
|
import jax.numpy as jnp |
|
from flax.core.frozen_dict import FrozenDict, unfreeze |
|
from jax import lax |
|
from jax.random import PRNGKey |
|
import numpy as np |
|
|
|
from transformers.modeling_flax_outputs import (
    FlaxBaseModelOutput,
    FlaxCausalLMOutputWithCrossAttentions,
    FlaxSeq2SeqLMOutput,
)
|
from transformers.modeling_flax_utils import FlaxPreTrainedModel |
|
from transformers.utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
|
from transformers.generation_flax_utils import FlaxLogitsProcessorList |
|
from models import ( |
|
FlaxWav2Vec2Model, |
|
FlaxWav2Vec2Module, |
|
FlaxBartForCausalLM, |
|
FlaxBartForCausalLMModule, |
|
BartConfig, |
|
Wav2Vec2Config, |
|
SpeechEncoderDecoderConfig, |
|
) |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
_CONFIG_FOR_DOC = "SpeechEncoderDecoderConfig" |
|
|
|
SPEECH_ENCODER_DECODER_START_DOCSTRING = r""" |
|
This class can be used to initialize a speech-sequence-to-text-sequence model with any pretrained speech |
|
autoencoding model as the encoder and any pretrained text autoregressive model as the decoder. The encoder is |
|
loaded via [`~AutoModel.from_pretrained`] function and the decoder is loaded via |
|
[`~AutoModelForCausalLM.from_pretrained`] function. Cross-attention layers are automatically added to the decoder |
|
and should be fine-tuned on a downstream generative task, like summarization. |
|
|
|
    The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation
    tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation
    Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan and Aliaksei Severyn.
|
|
|
Additionally, in [Large-Scale Self- and Semi-Supervised Learning for Speech |
|
Translation](https://arxiv.org/abs/2104.06678) it is shown how leveraging large pretrained speech models for speech |
|
translation yields a significant performance improvement. |
|
|
|
    After such a Speech-Encoder-Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other
    model (see the examples for more information).
|
|
|
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the |
|
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
    heads, etc.)
|
|
|
This model is also a Flax Linen |
|
[flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a |
|
    regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.
|
|
|
Parameters: |
|
config ([`SpeechEncoderDecoderConfig`]): Model configuration class with all the parameters of the model. |
|
Initializing with a config file does not load the weights associated with the model, only the |
|
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. |
|
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): |
|
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and |
|
`jax.numpy.bfloat16` (on TPUs). |
|
|
|
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If |
|
specified all the computation will be performed with the given `dtype`. |
|
|
|
**Note that this only specifies the dtype of the computation and does not influence the dtype of model |
|
parameters.** |
|
|
|
If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and |
|
[`~FlaxPreTrainedModel.to_bf16`]. |
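
    For example, to load and run the model in `bfloat16` (a minimal sketch: the checkpoints are the ones used in the
    usage examples below, and `dtype` is forwarded by
    [`~FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained`]):

    ```python
    >>> import jax.numpy as jnp
    >>> from transformers import FlaxSpeechEncoderDecoderModel

    >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
    ...     "facebook/wav2vec2-large-lv60", "facebook/bart-large", dtype=jnp.bfloat16
    ... )
    ```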
|
""" |
|
|
|
SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING = r""" |
|
Args: |
|
inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*): |
|
Float values of input raw speech waveform or speech features. Values can be obtained by loading a *.flac* |
|
or *.wav* audio file into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile |
|
library (*pip install soundfile*). To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or |
|
[`Speech2TextProcessor`] should be used for padding and conversion into a tensor of type |
|
            *numpy.ndarray*.
|
attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): |
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: |
|
|
|
- 1 for tokens that are **not masked**, |
|
- 0 for tokens that are **masked**. |
|
|
|
[What are attention masks?](../glossary#attention-mask) |
|
decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): |
|
Indices of decoder input sequence tokens in the vocabulary. |
|
|
|
Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and |
|
[`PreTrainedTokenizer.__call__`] for details. |
|
|
|
[What are input IDs?](../glossary#input-ids) |
|
|
|
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see |
|
`past_key_values`). |
|
|
|
For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be |
|
created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id` |
|
and prepending them with the `decoder_start_token_id`. |
|
decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): |
|
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also |
|
be used by default. |
|
decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): |
|
            Indices of positions of each decoder input sequence token in the position embeddings. Selected in the
|
range `[0, config.decoder.max_position_embeddings - 1]`. |
|
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
|
return_dict (`bool`, *optional*): |
|
If set to `True`, the model will return a [`~utils.FlaxSeq2SeqLMOutput`] instead of a plain tuple. |
|
""" |
|
|
|
SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING = r""" |
|
Args: |
|
inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*): |
|
Float values of input raw speech waveform or speech features. Values can be obtained by loading a *.flac* |
|
or *.wav* audio file into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile |
|
library (*pip install soundfile*). To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or |
|
[`Speech2TextProcessor`] should be used for padding and conversion into a tensor of type |
|
            *numpy.ndarray*.
|
attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): |
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: |
|
|
|
- 1 for tokens that are **not masked**, |
|
- 0 for tokens that are **masked**. |
|
|
|
[What are attention masks?](../glossary#attention-mask) |
|
output_attentions (`bool`, *optional*): |
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned |
|
tensors for more detail. |
|
output_hidden_states (`bool`, *optional*): |
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for |
|
more detail. |
|
return_dict (`bool`, *optional*): |
|
If set to `True`, the model will return a [`~utils.FlaxBaseModelOutput`] instead of a plain tuple. |
|
""" |
|
|
|
SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r""" |
|
Args: |
|
decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): |
|
Indices of decoder input sequence tokens in the vocabulary. |
|
|
|
Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and |
|
[`PreTrainedTokenizer.__call__`] for details. |
|
|
|
[What are decoder input IDs?](../glossary#decoder-input-ids) |
|
|
|
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see |
|
`past_key_values`). |
|
|
|
For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be |
|
created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id` |
|
and prepending them with the `decoder_start_token_id`. |
|
        encoder_outputs (`tuple(tuple(jnp.ndarray))`):
            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`).
            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden-states
            at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): |
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: |
|
|
|
- 1 for tokens that are **not masked**, |
|
- 0 for tokens that are **masked**. |
|
|
|
[What are attention masks?](../glossary#attention-mask) |
|
decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): |
|
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also |
|
be used by default. |
|
decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): |
|
            Indices of positions of each decoder input sequence token in the position embeddings. Selected in the
|
range `[0, config.decoder.max_position_embeddings - 1]`. |
|
past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): |
|
Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast |
|
auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. |
|
output_attentions (`bool`, *optional*): |
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned |
|
tensors for more detail. |
|
output_hidden_states (`bool`, *optional*): |
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for |
|
more detail. |
|
return_dict (`bool`, *optional*): |
|
If set to `True`, the model will return a [`~utils.FlaxCausalLMOutputWithCrossAttentions`] instead of a |
|
plain tuple. |
|
""" |
|
|
|
@flax.struct.dataclass |
|
class FlaxBeamSearchOutput(ModelOutput): |
|
""" |
|
Flax Base class for outputs of decoder-only generation models using greedy search. |
|
|
|
|
|
Args: |
|
sequences (`jnp.ndarray` of shape `(batch_size, max_length)`): |
|
The generated sequences. |
|
scores (`jnp.ndarray` of shape `(batch_size,)`): |
|
The scores (log probabilites) of the generated sequences. |
|
""" |
|
|
|
sequences: jnp.ndarray = None |
|
scores: jnp.ndarray = None |
|
|
|
|
|
@flax.struct.dataclass |
|
class BeamSearchState: |
|
cur_len: jnp.ndarray |
|
running_sequences: jnp.ndarray |
|
running_scores: jnp.ndarray |
|
sequences: jnp.ndarray |
|
scores: jnp.ndarray |
|
is_sent_finished: jnp.ndarray |
|
model_kwargs: Dict[str, jnp.ndarray] |
|
|
class FlaxSpeechEncoderDecoderModule(nn.Module): |
|
config: SpeechEncoderDecoderConfig |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def setup(self): |
|
encoder_config = self.config.encoder |
|
decoder_config = self.config.decoder |
|
|
|
|
|
encoder_module = FlaxWav2Vec2Module |
|
decoder_module = FlaxBartForCausalLMModule |
|
|
|
self.encoder = encoder_module(encoder_config, dtype=self.dtype) |
|
self.decoder = decoder_module(decoder_config, dtype=self.dtype) |
|
|
|
|
|
if ( |
|
self.encoder.config.hidden_size != self.decoder.config.hidden_size |
|
and self.decoder.config.cross_attention_hidden_size is None |
|
): |
|
self.enc_to_dec_proj = nn.Dense( |
|
self.decoder.config.hidden_size, |
|
kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range), |
|
dtype=self.dtype, |
|
) |
|
else: |
|
self.enc_to_dec_proj = None |
|
|
|
def _get_feat_extract_output_lengths( |
|
self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None |
|
): |
|
""" |
|
Computes the output length of the convolutional layers |
|
""" |
|
|
|
add_adapter = self.config.encoder.add_adapter if add_adapter is None else add_adapter |
|
|
|
def _conv_out_length(input_length, kernel_size, stride): |
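            # 1D convolutional layer output length formula (no padding, dilation 1):
            # floor((input_length - kernel_size) / stride) + 1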
|
|
|
|
|
return (input_length - kernel_size) // stride + 1 |
|
|
|
for kernel_size, stride in zip(self.config.encoder.conv_kernel, self.config.encoder.conv_stride): |
|
input_lengths = _conv_out_length(input_lengths, kernel_size, stride) |
|
|
|
if add_adapter: |
|
for _ in range(self.config.encoder.num_adapter_layers): |
|
input_lengths = _conv_out_length(input_lengths, 1, self.config.encoder.adapter_stride) |
|
|
|
return input_lengths |
|
|
|
def _get_encoder_module(self): |
|
return self.encoder |
|
|
|
def _get_projection_module(self): |
|
return self.enc_to_dec_proj |
|
|
|
def _get_decoder_module(self): |
|
return self.decoder |
|
|
|
def __call__( |
|
self, |
|
inputs, |
|
attention_mask, |
|
decoder_input_ids, |
|
decoder_attention_mask, |
|
decoder_position_ids, |
|
encoder_outputs=None, |
|
extract_features=None, |
|
output_attentions: bool = False, |
|
output_hidden_states: bool = False, |
|
output_features: bool = False, |
|
return_dict: bool = True, |
|
deterministic: bool = True, |
|
freeze_feature_encoder: bool = False, |
|
): |
|
if encoder_outputs is None: |
|
encoder_outputs = self.encoder( |
|
inputs, |
|
attention_mask=attention_mask, |
|
extract_features=extract_features, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
output_features=output_features, |
|
return_dict=return_dict, |
|
deterministic=deterministic, |
|
freeze_feature_encoder=freeze_feature_encoder, |
|
) |
|
|
|
if output_features: |
|
return encoder_outputs |
|
|
|
encoder_hidden_states = encoder_outputs[0] |
|
|
|
|
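        # optionally project the encoder hidden states to the decoder's hidden size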
|
if self.enc_to_dec_proj is not None: |
|
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) |
|
|
|
|
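        # compute the encoder attention mask that corresponds to the downsampled feature sequence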
|
if attention_mask is not None: |
|
encoder_attention_mask = self.encoder._get_feature_vector_attention_mask( |
|
encoder_hidden_states.shape[1], attention_mask |
|
) |
|
else: |
|
encoder_attention_mask = None |
|
|
|
|
|
decoder_outputs = self.decoder( |
|
input_ids=decoder_input_ids, |
|
attention_mask=decoder_attention_mask, |
|
position_ids=decoder_position_ids, |
|
encoder_hidden_states=encoder_hidden_states, |
|
encoder_attention_mask=encoder_attention_mask, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
deterministic=deterministic, |
|
) |
|
|
|
if not return_dict: |
|
return decoder_outputs + encoder_outputs |
|
|
|
return FlaxSeq2SeqLMOutput( |
|
logits=decoder_outputs.logits, |
|
decoder_hidden_states=decoder_outputs.hidden_states, |
|
decoder_attentions=decoder_outputs.attentions, |
|
cross_attentions=decoder_outputs.cross_attentions, |
|
encoder_last_hidden_state=encoder_hidden_states, |
|
encoder_hidden_states=encoder_outputs.hidden_states, |
|
encoder_attentions=encoder_outputs.attentions, |
|
) |
|
|
|
|
|
@add_start_docstrings(SPEECH_ENCODER_DECODER_START_DOCSTRING) |
|
class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel): |
|
r""" |
|
    [`FlaxSpeechEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture
    with the module (flax.nn.Module) of one of the base model classes of the library as encoder module and another one
    as decoder module when created with the [`~transformers.FlaxAutoModel.from_pretrained`] class method for the
    encoder and the [`~transformers.FlaxAutoModelForCausalLM.from_pretrained`] class method for the decoder.
    """
|
|
|
config_class = SpeechEncoderDecoderConfig |
|
base_model_prefix: str = "speech_encoder_decoder" |
|
module_class = FlaxSpeechEncoderDecoderModule |
|
|
|
def __init__( |
|
self, |
|
config: SpeechEncoderDecoderConfig, |
|
input_shape: Optional[Tuple] = None, |
|
seed: int = 0, |
|
dtype: jnp.dtype = jnp.float32, |
|
_do_init: bool = True, |
|
**kwargs |
|
): |
|
|
|
if not _do_init: |
|
raise ValueError( |
|
"`FlaxSpeechEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`." |
|
) |
|
|
|
if config.decoder.cross_attention_hidden_size is not None: |
|
|
|
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size: |
|
raise ValueError( |
|
"If `cross_attention_hidden_size` is specified in the decoder's configuration, " |
|
"it has to be equal to the encoder's `hidden_size`. " |
|
f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` " |
|
f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`." |
|
) |
|
|
|
|
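        # make sure the input and output word embeddings are not tied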
|
config.tie_word_embeddings = False |
|
module = self.module_class(config=config, dtype=dtype, **kwargs) |
|
|
|
if input_shape is None: |
|
|
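            # speech encoders almost always downsample the sequence length dimension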
|
encoder_input_length = 1024 |
|
decoder_input_length = module._get_feat_extract_output_lengths(encoder_input_length) |
|
input_shape = ((1, encoder_input_length), (1, decoder_input_length)) |
|
|
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) |
|
|
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: |
|
encoder_input_shape, decoder_input_shape = input_shape |
|
|
|
|
|
inputs = jnp.zeros(encoder_input_shape, dtype="f4") |
|
attention_mask = jnp.ones_like(inputs, dtype="i4") |
|
decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4") |
|
decoder_attention_mask = jnp.ones_like(decoder_input_ids) |
|
|
|
batch_size, sequence_length = inputs.shape |
|
|
|
decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape |
|
        if decoder_batch_size != batch_size:
|
raise ValueError( |
|
f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder and {decoder_batch_size} for decoder." |
|
) |
|
decoder_position_ids = jnp.broadcast_to( |
|
jnp.arange(decoder_sequence_length)[None, :], (decoder_batch_size, decoder_sequence_length) |
|
) |
|
|
|
params_rng, dropout_rng = jax.random.split(rng) |
|
rngs = {"params": params_rng, "dropout": dropout_rng} |
|
|
|
return self.module.init( |
|
rngs, |
|
inputs, |
|
attention_mask, |
|
decoder_input_ids, |
|
decoder_attention_mask, |
|
decoder_position_ids, |
|
)["params"] |
|
|
|
def init_cache(self, batch_size, max_length, encoder_outputs): |
|
r""" |
|
Args: |
|
batch_size (`int`): |
|
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. |
|
max_length (`int`): |
|
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized |
|
cache. |
|
encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): |
|
`encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: |
|
`attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) |
|
is a sequence of hidden-states at the output of the last layer of the encoder. Used in the |
|
cross-attention of the decoder. |
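
        Example (a minimal sketch: `model` is assumed to be a [`FlaxSpeechEncoderDecoderModel`], and
        `encoder_outputs` are obtained from [`~FlaxSpeechEncoderDecoderModel.encode`]):

        ```python
        >>> import jax.numpy as jnp

        >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
        >>> encoder_outputs = model.encode(inputs)
        >>> past_key_values = model.init_cache(batch_size=2, max_length=32, encoder_outputs=encoder_outputs)
        ```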
|
""" |
|
|
|
decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") |
|
decoder_attention_mask = jnp.ones_like(decoder_input_ids) |
|
decoder_position_ids = jnp.broadcast_to( |
|
jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape |
|
) |
|
|
|
def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): |
|
decoder_module = module._get_decoder_module() |
|
return decoder_module( |
|
input_ids=decoder_input_ids, |
|
attention_mask=decoder_attention_mask, |
|
position_ids=decoder_position_ids, |
|
**kwargs, |
|
) |
|
|
|
init_variables = self.module.init( |
|
jax.random.PRNGKey(0), |
|
decoder_input_ids=decoder_input_ids, |
|
decoder_attention_mask=decoder_attention_mask, |
|
decoder_position_ids=decoder_position_ids, |
|
encoder_hidden_states=encoder_outputs[0], |
|
init_cache=True, |
|
method=_decoder_forward, |
|
) |
|
return unfreeze(init_variables["cache"]) |
|
|
|
def _get_feat_extract_output_lengths( |
|
self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None |
|
): |
|
return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter) |
|
|
|
@add_start_docstrings(SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING) |
|
@replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=_CONFIG_FOR_DOC) |
|
def encode( |
|
self, |
|
inputs: jnp.ndarray, |
|
attention_mask: Optional[jnp.ndarray] = None, |
|
extract_features: Optional[jnp.ndarray] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
output_features: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
train: bool = False, |
|
freeze_feature_encoder: bool = False, |
|
params: dict = None, |
|
dropout_rng: PRNGKey = None, |
|
): |
|
r""" |
|
Returns: |
|
|
|
Example: |
|
|
|
```python |
|
        >>> from transformers import FlaxSpeechEncoderDecoderModel
        >>> import jax.numpy as jnp
|
|
|
>>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized |
|
>>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained( |
|
... "facebook/wav2vec2-large-lv60", "facebook/bart-large" |
|
... ) |
|
|
|
>>> inputs = jnp.ones((2, 5000), dtype=jnp.float32) |
|
>>> encoder_outputs = model.encode(inputs) |
|
```""" |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.return_dict |
|
|
|
if attention_mask is None: |
|
attention_mask = jnp.ones_like(inputs, dtype="i4") |
|
|
|
if extract_features is not None: |
|
extract_features = jnp.array(extract_features, dtype="f4") |
|
|
|
|
|
rngs = {} |
|
if dropout_rng is not None: |
|
rngs["dropout"] = dropout_rng |
|
|
|
def _encoder_forward(module, inputs, attention_mask, **kwargs): |
|
encode_module = module._get_encoder_module() |
|
return encode_module(inputs, attention_mask, **kwargs) |
|
|
|
outputs = self.module.apply( |
|
{"params": params or self.params}, |
|
inputs=jnp.array(inputs, dtype="f4"), |
|
attention_mask=jnp.array(attention_mask, dtype="i4"), |
|
extract_features=extract_features, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
output_features=output_features, |
|
return_dict=return_dict, |
|
deterministic=not train, |
|
freeze_feature_encoder=freeze_feature_encoder, |
|
rngs=rngs, |
|
method=_encoder_forward, |
|
) |
|
|
|
if return_dict and not output_features: |
|
outputs = FlaxBaseModelOutput( |
|
last_hidden_state=outputs.last_hidden_state, |
|
hidden_states=outputs.hidden_states, |
|
attentions=outputs.attentions, |
|
) |
|
|
|
return outputs |
|
|
|
@add_start_docstrings(SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING) |
|
@replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) |
|
def decode( |
|
self, |
|
decoder_input_ids, |
|
encoder_outputs, |
|
encoder_attention_mask: Optional[jnp.ndarray] = None, |
|
decoder_attention_mask: Optional[jnp.ndarray] = None, |
|
decoder_position_ids: Optional[jnp.ndarray] = None, |
|
past_key_values: dict = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
train: bool = False, |
|
params: dict = None, |
|
dropout_rng: PRNGKey = None, |
|
): |
|
r""" |
|
Returns: |
|
|
|
Example: |
|
|
|
```python |
|
>>> from transformers import FlaxSpeechEncoderDecoderModel |
|
>>> import jax.numpy as jnp |
|
|
|
>>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized |
|
>>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained( |
|
... "facebook/wav2vec2-large-lv60", "facebook/bart-large" |
|
... ) |
|
|
|
>>> inputs = jnp.ones((2, 5000), dtype=jnp.float32) |
|
>>> encoder_outputs = model.encode(inputs) |
|
|
|
>>> decoder_start_token_id = model.config.decoder.bos_token_id |
|
>>> decoder_input_ids = jnp.ones((inputs.shape[0], 1), dtype="i4") * decoder_start_token_id |
|
|
|
>>> outputs = model.decode(decoder_input_ids, encoder_outputs) |
|
>>> logits = outputs.logits |
|
```""" |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.return_dict |
|
|
|
encoder_hidden_states = encoder_outputs[0] |
|
if encoder_attention_mask is None: |
|
batch_size, sequence_length = encoder_hidden_states.shape[:2] |
|
encoder_attention_mask = jnp.ones((batch_size, sequence_length)) |
|
|
|
batch_size, sequence_length = decoder_input_ids.shape |
|
if decoder_attention_mask is None: |
|
decoder_attention_mask = jnp.ones((batch_size, sequence_length)) |
|
|
|
if decoder_position_ids is None: |
|
if past_key_values is not None: |
|
raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") |
|
|
|
decoder_position_ids = jnp.broadcast_to( |
|
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) |
|
) |
|
|
|
|
|
rngs = {} |
|
if dropout_rng is not None: |
|
rngs["dropout"] = dropout_rng |
|
|
|
        inputs = {"params": params or self.params}
|
|
|
|
|
|
|
|
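        # if `past_key_values` are passed, the cache has already been initialized by
        # `init_cache`; mark it as mutable so the decoder's attention modules can update it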
|
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False
|
|
|
def _decoder_forward( |
|
module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, encoder_hidden_states, **kwargs |
|
): |
|
|
|
projection_module = module._get_projection_module() |
|
decoder_module = module._get_decoder_module() |
|
|
|
|
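            # optionally project encoder_hidden_states to the decoder's hidden size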
|
if projection_module is not None: |
|
encoder_hidden_states = projection_module(encoder_hidden_states) |
|
|
|
return decoder_module( |
|
decoder_input_ids, |
|
decoder_attention_mask, |
|
decoder_position_ids, |
|
encoder_hidden_states, |
|
**kwargs, |
|
) |
|
|
|
        outputs = self.module.apply(
            inputs,
|
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), |
|
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), |
|
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), |
|
encoder_hidden_states=encoder_hidden_states, |
|
encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
deterministic=not train, |
|
rngs=rngs, |
|
mutable=mutable, |
|
method=_decoder_forward, |
|
) |
|
|
|
|
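        # add the updated cache to the model output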
|
if past_key_values is not None and return_dict: |
|
outputs, past = outputs |
|
outputs["past_key_values"] = unfreeze(past["cache"]) |
|
return outputs |
|
elif past_key_values is not None and not return_dict: |
|
outputs, past = outputs |
|
outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] |
|
|
|
return outputs |
|
|
|
@add_start_docstrings_to_model_forward(SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING) |
|
@replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) |
|
def __call__( |
|
self, |
|
inputs: jnp.ndarray, |
|
attention_mask: Optional[jnp.ndarray] = None, |
|
extract_features: Optional[jnp.ndarray] = None, |
|
decoder_input_ids: Optional[jnp.ndarray] = None, |
|
decoder_attention_mask: Optional[jnp.ndarray] = None, |
|
decoder_position_ids: Optional[jnp.ndarray] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
output_features: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
train: bool = False, |
|
freeze_feature_encoder: bool = False, |
|
params: dict = None, |
|
dropout_rng: PRNGKey = None, |
|
): |
|
r""" |
|
Returns: |
|
|
|
Examples: |
|
|
|
```python |
|
        >>> from transformers import FlaxSpeechEncoderDecoderModel, BartTokenizer
        >>> import jax.numpy as jnp
|
|
|
>>> # load a fine-tuned wav2vec2-2-bart model |
|
>>> model = FlaxSpeechEncoderDecoderModel.from_pretrained("patrickvonplaten/wav2vec2-2-bart-large") |
|
>>> # load output tokenizer |
|
>>> tokenizer_output = BartTokenizer.from_pretrained("facebook/bart-large") |
|
|
|
>>> inputs = jnp.ones((2, 5000), dtype=jnp.float32) |
|
|
|
>>> # use bart's special bos, pad and eos tokens |
|
>>> model.config.decoder_start_token_id = model.decoder.config.bos_token_id |
|
>>> model.config.pad_token_id = model.decoder.config.pad_token_id |
|
>>> model.config.eos_token_id = model.decoder.config.eos_token_id |
|
|
|
>>> outputs = model.generate(inputs) |
``` |
|
""" |
|
|
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.return_dict |
|
|
|
|
|
if attention_mask is None: |
|
attention_mask = jnp.ones_like(inputs, dtype="i4") |
|
|
|
if extract_features is not None: |
|
inputs = None |
|
extract_features = jnp.array(extract_features, dtype="f4") |
|
else: |
|
inputs = jnp.array(inputs, dtype="f4") |
|
|
|
|
|
if decoder_input_ids is None: |
|
raise ValueError( |
|
"`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_position_ids` must be specified as an input argument." |
|
) |
|
if decoder_attention_mask is None: |
|
decoder_attention_mask = jnp.ones_like(decoder_input_ids) |
|
if decoder_position_ids is None: |
|
batch_size, sequence_length = decoder_input_ids.shape |
|
decoder_position_ids = jnp.broadcast_to( |
|
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) |
|
) |
|
|
|
|
|
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} |
|
|
|
return self.module.apply( |
|
{"params": params or self.params}, |
|
inputs=inputs, |
|
attention_mask=jnp.array(attention_mask, dtype="i4"), |
|
extract_features=extract_features, |
|
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), |
|
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), |
|
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
output_features=output_features, |
|
return_dict=return_dict, |
|
deterministic=not train, |
|
freeze_feature_encoder=freeze_feature_encoder, |
|
rngs=rngs, |
|
) |
|
|
|
def prepare_inputs_for_generation( |
|
self, |
|
decoder_input_ids, |
|
max_length, |
|
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
|
encoder_outputs=None, |
|
**kwargs |
|
): |
|
|
|
batch_size, seq_length = decoder_input_ids.shape |
|
|
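        # initialize the attention cache for fast auto-regressive decoding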
|
past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) |
|
|
|
|
|
|
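        # usually one would have to put 0's in the attention mask for positions beyond the
        # current input length, but since the decoder uses a causal mask those positions are
        # masked anyway; a single static mask of length max_length is more efficient to compile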
|
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") |
|
if decoder_attention_mask is not None: |
|
decoder_position_ids = decoder_attention_mask.cumsum(axis=-1) - 1 |
|
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0)) |
|
else: |
|
decoder_position_ids = jnp.broadcast_to( |
|
jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length) |
|
) |
|
|
|
return { |
|
"past_key_values": past_key_values, |
|
"encoder_outputs": encoder_outputs, |
|
"encoder_attention_mask": attention_mask, |
|
"decoder_attention_mask": extended_attention_mask, |
|
"decoder_position_ids": decoder_position_ids, |
|
} |
|
|
|
def update_inputs_for_generation(self, model_outputs, model_kwargs): |
|
model_kwargs["past_key_values"] = model_outputs.past_key_values |
|
model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1 |
|
return model_kwargs |
|
|
|
@classmethod |
|
def from_encoder_decoder_pretrained( |
|
cls, |
|
encoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, |
|
decoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, |
|
*model_args, |
|
**kwargs |
|
) -> FlaxPreTrainedModel: |
|
r""" |
|
Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model |
|
checkpoints. |
|
|
|
Params: |
|
encoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*): |
|
Information necessary to initiate the encoder. Can be either: |
|
|
|
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. |
|
Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a |
|
user or organization name, like `dbmdz/bert-base-german-cased`. |
|
- A path to a *directory* containing model weights saved using |
|
[`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. |
|
|
|
decoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*, defaults to `None`): |
|
Information necessary to initiate the decoder. Can be either: |
|
|
|
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. |
|
Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a |
|
user or organization name, like `dbmdz/bert-base-german-cased`. |
|
- A path to a *directory* containing model weights saved using |
|
[`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. |
|
|
|
model_args (remaining positional arguments, *optional*): |
|
                All remaining positional arguments will be passed to the underlying model's `__init__` method.
|
|
|
kwargs (remaining dictionary of keyword arguments, *optional*): |
|
                Can be used to update the configuration object (after it has been loaded) and initiate the model
                (e.g.,
|
`output_attentions=True`). |
|
|
|
- To update the encoder configuration, use the prefix *encoder_* for each configuration parameter. |
|
- To update the decoder configuration, use the prefix *decoder_* for each configuration parameter. |
|
- To update the parent model configuration, do not use a prefix for each configuration parameter. |
|
|
|
Behaves differently depending on whether a `config` is provided or automatically loaded. |
|
|
|
Example: |
|
|
|
```python |
|
>>> from transformers import FlaxSpeechEncoderDecoderModel |
|
|
|
>>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized |
|
>>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained( |
|
... "facebook/wav2vec2-large-lv60", "facebook/bart-large" |
|
... ) |
|
>>> # saving model after fine-tuning |
|
>>> model.save_pretrained("./wav2vec2-2-bart-large") |
|
>>> # load fine-tuned model |
|
>>> model = FlaxSpeechEncoderDecoderModel.from_pretrained("./wav2vec2-2-bart-large") |
|
```""" |
|
|
|
kwargs_encoder = { |
|
argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_") |
|
} |
|
|
|
kwargs_decoder = { |
|
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") |
|
} |
|
|
|
|
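        # remove the `encoder_` / `decoder_`-prefixed kwargs from the shared kwargs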
|
for key in kwargs_encoder.keys(): |
|
del kwargs["encoder_" + key] |
|
for key in kwargs_decoder.keys(): |
|
del kwargs["decoder_" + key] |
|
|
|
|
|
|
|
|
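        # load and initialize the encoder and decoder; the two are distinguished by the
        # `is_decoder` and `add_cross_attention` config flags set below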
|
encoder = kwargs_encoder.pop("model", None) |
|
if encoder is None: |
|
if encoder_pretrained_model_name_or_path is None: |
|
raise ValueError( |
|
"If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has " |
|
"to be defined." |
|
) |
|
|
|
if "config" not in kwargs_encoder: |
|
|
|
encoder_config, kwargs_encoder = Wav2Vec2Config.from_pretrained( |
|
encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True |
|
) |
|
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: |
|
logger.info( |
|
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model " |
|
"from a decoder model. Cross-attention and casual mask are disabled." |
|
) |
|
encoder_config.is_decoder = False |
|
encoder_config.add_cross_attention = False |
|
|
|
kwargs_encoder["config"] = encoder_config |
|
|
|
|
|
encoder = FlaxWav2Vec2Model.from_pretrained( |
|
encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder |
|
) |
|
|
|
decoder = kwargs_decoder.pop("model", None) |
|
if decoder is None: |
|
if decoder_pretrained_model_name_or_path is None: |
|
raise ValueError( |
|
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has " |
|
"to be defined." |
|
) |
|
|
|
if "config" not in kwargs_decoder: |
|
|
|
decoder_config, kwargs_decoder = BartConfig.from_pretrained( |
|
decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True |
|
) |
|
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: |
|
logger.info( |
|
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. " |
|
f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} " |
|
f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for " |
|
"cross attention layers." |
|
) |
|
decoder_config.is_decoder = True |
|
decoder_config.add_cross_attention = True |
|
|
|
kwargs_decoder["config"] = decoder_config |
|
|
|
if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False: |
|
logger.warning( |
|
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. " |
|
f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, " |
|
"make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` " |
|
"passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a " |
|
"`decoder_config` to `.from_encoder_decoder_pretrained(...)`" |
|
) |
|
|
|
|
|
decoder = FlaxBartForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) |
|
|
|
|
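        # instantiate the joint config with the remaining (un-prefixed) kwargs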
|
dtype = kwargs.pop("dtype", jnp.float32) |
|
config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs) |
|
|
|
|
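        # make sure the input and output word embeddings are not tied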
|
config.tie_word_embeddings = False |
|
|
|
|
|
model = cls(config, dtype=dtype) |
|
model.params["encoder"] = encoder.params |
|
model.params["decoder"] = decoder.params |
|
|
|
return model |
|
|
|
def _beam_search( |
|
self, |
|
        input_ids: jnp.ndarray,
|
max_length: Optional[int] = None, |
|
pad_token_id: Optional[int] = None, |
|
eos_token_id: Optional[int] = None, |
|
length_penalty: Optional[float] = None, |
|
early_stopping: Optional[bool] = None, |
|
logits_processor: Optional[FlaxLogitsProcessorList] = None, |
|
trace: bool = True, |
|
params: Optional[Dict[str, jnp.ndarray]] = None, |
|
model_kwargs: Optional[Dict[str, jnp.ndarray]] = None, |
|
): |
|
""" |
|
This beam search function is heavily inspired by Flax's official example: |
|
https://github.com/google/flax/blob/master/examples/wmt/train.py#L254 |
|
""" |
|
|
|
def flatten_beam_dim(tensor): |
|
"""Flattens the first two dimensions of a non-scalar array.""" |
|
|
|
if tensor.ndim == 0 or tensor.ndim == 1: |
|
return tensor |
|
elif tensor.ndim == 6: |
|
return tensor.reshape(tensor.shape[:1] + (tensor.shape[1] * tensor.shape[2],) + tensor.shape[3:]) |
|
return tensor.reshape((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:]) |
|
|
|
def unflatten_beam_dim(tensor, batch_size, num_beams): |
|
"""Unflattens the first, flat batch*beam dimension of a non-scalar array.""" |
|
|
|
if tensor.ndim == 0 or tensor.ndim == 1: |
|
return tensor |
|
if tensor.ndim == 5: |
|
return tensor.reshape(tensor.shape[:1] + (batch_size, num_beams) + tensor.shape[2:]) |
|
return tensor.reshape((batch_size, num_beams) + tensor.shape[1:]) |
|
|
|
def gather_beams(nested, beam_indices, batch_size, new_num_beams): |
|
""" |
|
Gathers the beam slices indexed by beam_indices into new beam array. |
|
""" |
|
batch_indices = jnp.reshape( |
|
jnp.arange(batch_size * new_num_beams) // new_num_beams, (batch_size, new_num_beams) |
|
) |
|
|
|
def gather_fn(tensor): |
|
|
|
if tensor.ndim == 0 or tensor.ndim == 1: |
|
return tensor |
|
if tensor.ndim == 6: |
|
return tensor[:, batch_indices, beam_indices] |
|
return tensor[batch_indices, beam_indices] |
|
|
|
return jax.tree_map(gather_fn, nested) |
|
|
|
|
|
max_length = max_length if max_length is not None else self.config.max_length |
|
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id |
|
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id |
|
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty |
|
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping |
|
|
|
batch_size, num_beams, cur_len = input_ids.shape |
|
|
|
eos_token_id = jnp.array(eos_token_id) |
|
pad_token_id = jnp.array(pad_token_id) |
|
cur_len = jnp.array(cur_len) |
|
|
|
|
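        # per batch-item and beam: finished sequences and still-running (live) sequences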
|
        sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
        running_sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0, 0))
|
|
|
|
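        # per batch-item and beam: a flag indicating whether the sentence has finished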
|
is_sent_finished = jnp.zeros((batch_size, num_beams), dtype=jnp.bool_) |
|
|
|
|
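        # per batch-item and beam: cumulative log-prob scores; all but the first running beam start at -1e7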
|
running_scores = jnp.tile(jnp.array([0.0] + [np.array(-1.0e7)] * (num_beams - 1)), [batch_size, 1]) |
|
scores = jnp.ones((batch_size, num_beams)) * np.array(-1.0e7) |
|
|
|
|
|
|
|
model = self.decode if self.config.is_encoder_decoder else self |
|
|
|
|
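        # flatten the beam dimension into the batch dimension for the model inputs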
|
if "encoder_outputs" in model_kwargs: |
|
model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim( |
|
model_kwargs["encoder_outputs"]["last_hidden_state"] |
|
) |
|
if "attention_mask" in model_kwargs: |
|
model_kwargs["attention_mask"] = flatten_beam_dim(model_kwargs["attention_mask"]) |
|
|
|
|
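        # initialize model-specific kwargs for generation (e.g. the attention cache)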
|
model_kwargs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), max_length, **model_kwargs) |
|
|
|
|
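        # initialize the beam search state carried through the generation loop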
|
state = BeamSearchState( |
|
cur_len=cur_len, |
|
running_sequences=running_sequences, |
|
running_scores=running_scores, |
|
sequences=sequences, |
|
scores=scores, |
|
is_sent_finished=is_sent_finished, |
|
model_kwargs=model_kwargs, |
|
) |
|
|
|
def beam_search_cond_fn(state): |
|
"""beam search state termination condition fn.""" |
|
|
|
|
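            # 1. is the current length still below max_length?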
|
not_max_length_yet = state.cur_len < max_length |
|
|
|
|
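            # 2. can the best running beam still improve on the worst finished beam?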
|
best_running_score = state.running_scores[:, -1:] / (max_length**length_penalty) |
|
worst_finished_score = jnp.where( |
|
state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7) |
|
) |
|
improvement_still_possible = jnp.all(worst_finished_score < best_running_score) |
|
|
|
|
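            # 3. is there still an unfinished beam (only terminates early when early_stopping is set)?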
|
still_open_beam = ~(jnp.all(state.is_sent_finished) & early_stopping) |
|
|
|
return not_max_length_yet & still_open_beam & improvement_still_possible |
|
|
|
def beam_search_body_fn(state, input_ids_length=1): |
|
"""beam search state update fn.""" |
|
|
|
|
|
|
|
|
|
|
|
|
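            # 1. forward current tokens: slice the current position from the running
            # sequences and flatten the beam dim into the batch dim for the model call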
|
input_token = flatten_beam_dim( |
|
lax.dynamic_slice( |
|
state.running_sequences, |
|
(0, 0, state.cur_len - input_ids_length), |
|
(batch_size, num_beams, input_ids_length), |
|
) |
|
) |
|
model_outputs = model(input_token, params=params, **state.model_kwargs) |
|
|
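            # un-flatten the beam dimension of the logits and of the attention cache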
|
logits = unflatten_beam_dim(model_outputs.logits[:, -1], batch_size, num_beams) |
|
cache = jax.tree_map( |
|
lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams), model_outputs.past_key_values |
|
) |
|
|
|
|
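            # model-specific adjustment of the logits for beam search (a no-op for most models)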
|
logits = self._adapt_logits_for_beam_search(logits) |
|
|
|
|
|
|
|
|
|
|
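            # 2. compute log probs: run the logits processors (e.g. min_length) and add
            # the new token log probs to the cumulative running scores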
|
log_probs = jax.nn.log_softmax(logits) |
|
            log_probs = logits_processor(
                flatten_beam_dim(state.running_sequences), flatten_beam_dim(log_probs), state.cur_len
            )
|
log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams) |
|
log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2) |
|
vocab_size = log_probs.shape[2] |
|
log_probs = log_probs.reshape((batch_size, num_beams * vocab_size)) |
|
|
|
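            # 3. retrieve the top 2*K candidates: each batch item has num_beams * vocab_size
            # candidate sequences; keeping 2*K ensures that even if K beams hit EOS at the
            # same time, K live beams remain to continue the search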
|
beams_to_keep = 2 * num_beams |
|
topk_log_probs, topk_indices = lax.top_k(log_probs, k=beams_to_keep) |
|
topk_beam_indices = topk_indices // vocab_size |
|
topk_running_sequences = gather_beams( |
|
state.running_sequences, topk_beam_indices, batch_size, beams_to_keep |
|
) |
|
topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2) |
|
topk_sequences = lax.dynamic_update_slice(topk_running_sequences, topk_ids, (0, 0, state.cur_len)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
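            # 4. mark candidates that just reached EOS and exclude them from the running set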
|
did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id |
|
running_topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7) |
|
|
|
|
|
|
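            # 5. of the top 2*K candidates, keep the K best still-running beams for the next step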
|
next_topk_indices = jnp.flip(lax.top_k(running_topk_log_probs, k=num_beams)[1], axis=1) |
|
next_running_sequences, next_running_scores = gather_beams( |
|
[topk_sequences, running_topk_log_probs], next_topk_indices, batch_size, num_beams |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
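            # 6. apply the length penalty and exclude candidates that did not just finish
            # (or whose batch item is already full of finished beams) from the finished pool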
|
topk_log_probs = topk_log_probs / (state.cur_len**length_penalty) |
|
beams_in_batch_are_full = ( |
|
jnp.broadcast_to(state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape) |
|
& early_stopping |
|
) |
|
add_penalty = ~did_topk_just_finished | beams_in_batch_are_full |
|
topk_log_probs += add_penalty * np.array(-1.0e7) |
|
|
|
|
|
|
|
|
|
|
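            # 7. merge the newly finished beams with the previously finished ones and keep the overall K best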
|
merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1) |
|
merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1) |
|
merged_is_sent_finished = jnp.concatenate([state.is_sent_finished, did_topk_just_finished], axis=1) |
|
topk_merged_indices = jnp.flip(lax.top_k(merged_scores, k=num_beams)[1], axis=1) |
|
next_sequences, next_scores, next_is_sent_finished = gather_beams( |
|
[merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices, batch_size, num_beams |
|
) |
|
|
|
|
|
|
|
|
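            # 8. gather the attention caches of the selected beams and update the model kwargs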
|
next_running_indices = gather_beams(topk_beam_indices, next_topk_indices, batch_size, num_beams) |
|
next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams) |
|
model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache) |
|
next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs) |
|
|
|
return BeamSearchState( |
|
cur_len=state.cur_len + 1, |
|
running_scores=next_running_scores, |
|
running_sequences=next_running_sequences, |
|
scores=next_scores, |
|
sequences=next_sequences, |
|
is_sent_finished=next_is_sent_finished, |
|
model_kwargs=next_model_kwargs, |
|
) |
|
|
|
|
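        # the very first step may process a prompt longer than one token, so run one body
        # step eagerly before entering the single-token-per-step while loop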
|
if input_ids.shape[-1] > 1: |
|
state = partial(beam_search_body_fn, input_ids_length=input_ids.shape[-1])(state) |
|
|
|
if not trace: |
|
state = self._run_loop_in_debug(beam_search_cond_fn, beam_search_body_fn, state) |
|
else: |
|
state = lax.while_loop(beam_search_cond_fn, beam_search_body_fn, state) |
|
|
|
|
|
|
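        # for batch items without any finished sequence, fall back to their running beams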
|
none_finished = jnp.any(state.is_sent_finished, axis=1) |
|
sequences = jnp.where(none_finished[:, None, None], state.sequences, state.running_sequences) |
|
scores = jnp.where(none_finished[:, None], state.scores, state.running_scores) |
|
|
|
|
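        # return all beams for each batch item (sorted ascending, best beam last), together
        # with the score of the best beam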
|
sequences = sequences[:, :] |
|
scores = scores[:, -1] |
|
|
|
return FlaxBeamSearchOutput(sequences=sequences, scores=scores) |
|
|