# coding=utf-8
# Copyright 2021 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Flax Bart model. """

import math
from functools import partial
from typing import Callable, Optional, Tuple

import numpy as np

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from jax import lax
from jax.random import PRNGKey
from transformers.modeling_flax_outputs import (
    FlaxBaseModelOutput,
    FlaxBaseModelOutputWithPastAndCrossAttentions,
    FlaxCausalLMOutputWithCrossAttentions,
    FlaxSeq2SeqLMOutput,
    FlaxSeq2SeqModelOutput,
)
from transformers.modeling_flax_utils import (
    ACT2FN,
    FlaxPreTrainedModel,
)
from transformers.utils import logging

from .configuration_bart import BartConfig


logger = logging.get_logger(__name__)


def shift_tokens_right(input_ids: np.ndarray, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
    """
    Shift input ids one token to the right.

    The first column of the result is ``decoder_start_token_id``; every other column holds the token that was
    one position to its left in ``input_ids``. Any ``-100`` left in the shifted ids is replaced by
    ``pad_token_id``.

    Args:
        input_ids: 2-D array of token ids (indexed as ``[:, position]``).
        pad_token_id: id written wherever the shifted ids contain ``-100``.
        decoder_start_token_id: id written into the first position of every row.

    Returns:
        A new NumPy array with the same shape as ``input_ids``.
    """
    # Work on a NumPy view/copy so that the in-place slice assignments below are always legal.
    input_ids = np.asarray(input_ids)
    shifted_input_ids = np.zeros_like(input_ids)
    shifted_input_ids[:, 1:] = input_ids[:, :-1]
    shifted_input_ids[:, 0] = decoder_start_token_id

    shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
    return shifted_input_ids


class FlaxBartAttention(nn.Module):
    """
    Multi-head attention used for encoder self-attention, decoder (causal) self-attention and cross-attention.

    All four projections (query, key, value, output) are bias-free dense layers of width ``embed_dim``.
    """

    config: BartConfig
    embed_dim: int
    num_heads: int
    dropout: float = 0.0
    causal: bool = False
    # NOTE(review): `bias` is not read anywhere in this class; the projections are built with use_bias=False.
    bias: bool = True
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self) -> None:
        """Create the projection layers, the dropout layer and (for causal attention) the causal mask."""
        self.head_dim = self.embed_dim // self.num_heads
        assert (
            self.head_dim * self.num_heads == self.embed_dim
        ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."

        dense = partial(
            nn.Dense,
            self.embed_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )

        self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
        self.out_proj = dense()

        self.dropout_layer = nn.Dropout(rate=self.dropout)

        if self.causal:
            # NOTE(review): the mask is sized from config.max_position_embeddings, whereas FlaxBartDecoder sizes
            # its position table from config.decoder_max_position_embeddings; confirm the two agree.
            self.causal_mask = make_causal_mask(
                jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
            )

    def _split_heads(self, hidden_states):
        """Reshape the last axis into ``(num_heads, head_dim)``, keeping the first two axes."""
        return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))

    def _merge_heads(self, hidden_states):
        """Inverse of ``_split_heads``: collapse the trailing axes back into ``embed_dim``."""
        return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))

    @nn.compact
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to cached
        states from previous steps. This function is slightly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252

        Returns:
            ``(key, value, attention_mask)``. When the cache already existed these are the full cached key/value
            buffers and the mask combined with a mask over not-yet-written cache slots; otherwise the inputs are
            returned unchanged (the call only creates the cache variables).
        """
        # detect if we're initializing by absence of existing cache data.
        is_initialized = self.has_variable("cache", "cached_key")
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # update key, value caches with our new 1d spatial slices
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            cached_key.value = key
            cached_value.value = value
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # causal mask for cached decoder self-attention: our single query position should only attend to those
            # key positions that have already been generated and cached, not the remaining zero elements.
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            attention_mask = combine_masks(pad_mask, attention_mask)
        return key, value, attention_mask

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        key_value_states: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
    ) -> jnp.ndarray:
        """
        Input shape: Batch x Time x Channel

        Args:
            hidden_states: query source.
            attention_mask: optional mask over key positions (non-zero means "attend"). ``None`` means no
                padding mask; causal attention then uses only the causal mask, and non-causal attention is
                left unmasked.
            key_value_states: when given, keys and values are projected from it (cross-attention); otherwise
                from ``hidden_states`` (self-attention).
            init_cache: for causal attention, create the key/value cache variables during this call.
            deterministic: when ``False`` (and ``dropout > 0``) dropout is applied to the attention weights.

        Returns:
            The attention output after the output projection.
        """
        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None
        batch_size = hidden_states.shape[0]

        # get query proj
        query_states = self.q_proj(hidden_states)
        # get key, value proj
        if is_cross_attention:
            # cross_attentions
            key_states = self.k_proj(key_value_states)
            value_states = self.v_proj(key_value_states)
        else:
            # self_attention
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        # handle cache prepare causal attention mask
        if self.causal:
            query_length, key_length = query_states.shape[1], key_states.shape[1]
            if self.has_variable("cache", "cached_key"):
                # With a cache, slice the rows of the causal mask starting at the current cache index.
                mask_shift = self.variables["cache"]["cache_index"]
                max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
                causal_mask = lax.dynamic_slice(
                    self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
                )
            else:
                causal_mask = self.causal_mask[:, :, :query_length, :key_length]
            causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])

        # combine masks if needed; a missing attention_mask is legal (e.g. cross-attention during cache init,
        # where no encoder attention mask is supplied)
        if attention_mask is not None and self.causal:
            attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
            attention_mask = combine_masks(attention_mask, causal_mask)
        elif self.causal:
            attention_mask = causal_mask
        elif attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
            key_states, value_states, attention_mask = self._concatenate_to_cache(
                key_states, value_states, query_states, attention_mask
            )

        # Convert the boolean attention mask to an attention bias.
        if attention_mask is not None:
            # attention mask in the form of attention bias
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)

        return attn_output


class FlaxBartEncoderLayer(nn.Module):
    """One encoder block: self-attention and a feed-forward network, each followed by residual add + LayerNorm."""

    config: BartConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        """Build the attention, feed-forward, dropout and layer-norm sub-modules."""
        self.embed_dim = self.config.d_model
        self.self_attn = FlaxBartAttention(
            config=self.config,
            embed_dim=self.embed_dim,
            num_heads=self.config.encoder_attention_heads,
            dropout=self.config.attention_dropout,
            dtype=self.dtype,
        )
        self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)
        self.activation_fn = ACT2FN[self.config.activation_function]
        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
        self.fc1 = nn.Dense(
            self.config.encoder_ffn_dim,
            dtype=self.dtype,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )
        self.fc2 = nn.Dense(
            self.embed_dim,
            dtype=self.dtype,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )
        self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: jnp.ndarray,
        deterministic: bool = True,
    ) -> jnp.ndarray:
        """
        Apply the encoder block.

        Args:
            hidden_states: input activations.
            attention_mask: mask forwarded to self-attention.
            deterministic: disables the residual/activation dropout layers when ``True``.

        Returns:
            The transformed hidden states.
        """
        residual = hidden_states
        # NOTE(review): `deterministic` is not forwarded to the attention module, so its attention-weight
        # dropout stays disabled (the attention default is deterministic=True) -- confirm this is intended.
        hidden_states = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask)
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
        hidden_states = self.fc2(hidden_states)
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        return hidden_states


class FlaxBartEncoderLayerCollection(nn.Module):
    """Stack of ``config.encoder_layers`` encoder layers applied in sequence."""

    config: BartConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        """Instantiate the layers, wrapped in ``nn.remat`` when ``config.gradient_checkpointing`` is set."""
        layer_module = nn.remat(FlaxBartEncoderLayer) if self.config.gradient_checkpointing else FlaxBartEncoderLayer
        self.layers = [
            layer_module(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.encoder_layers)
        ]

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
    ):
        """Run every layer in order and wrap the final hidden states in a ``FlaxBaseModelOutput``."""
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask,
                deterministic,
            )

        return FlaxBaseModelOutput(last_hidden_state=hidden_states)


class FlaxBartDecoderLayer(nn.Module):
    """
    One decoder block: causal self-attention, cross-attention over the encoder output and a feed-forward
    network, each followed by residual add + LayerNorm.
    """

    config: BartConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        """Build the attention, feed-forward, dropout and layer-norm sub-modules."""
        self.embed_dim = self.config.d_model
        self.self_attn = FlaxBartAttention(
            config=self.config,
            embed_dim=self.embed_dim,
            num_heads=self.config.decoder_attention_heads,
            dropout=self.config.attention_dropout,
            causal=True,
            dtype=self.dtype,
        )
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)
        self.activation_fn = ACT2FN[self.config.activation_function]
        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)

        self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
        self.encoder_attn = FlaxBartAttention(
            config=self.config,
            embed_dim=self.embed_dim,
            num_heads=self.config.decoder_attention_heads,
            dropout=self.config.attention_dropout,
            dtype=self.dtype,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
        # The decoder feed-forward width comes from the decoder setting (previously this read
        # `encoder_ffn_dim`, which only gives the same shape when both settings are equal).
        self.fc1 = nn.Dense(
            self.config.decoder_ffn_dim,
            dtype=self.dtype,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )
        self.fc2 = nn.Dense(
            self.embed_dim,
            dtype=self.dtype,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )
        self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: jnp.ndarray,
        encoder_hidden_states: jnp.ndarray,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
    ) -> jnp.ndarray:
        """
        Apply the decoder block.

        Args:
            hidden_states: decoder activations.
            attention_mask: mask forwarded to the causal self-attention.
            encoder_hidden_states: encoder output used as keys/values for cross-attention.
            encoder_attention_mask: optional mask forwarded to cross-attention.
            init_cache: forwarded to self-attention to create the decoding cache.
            deterministic: disables the residual/activation dropout layers when ``True``.

        Returns:
            The transformed hidden states.
        """
        residual = hidden_states

        # Self Attention
        # NOTE(review): `deterministic` is not forwarded to either attention module, so attention-weight
        # dropout stays disabled -- confirm this is intended.
        hidden_states = self.self_attn(
            hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache
        )
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Cross-Attention Block
        residual = hidden_states

        hidden_states = self.encoder_attn(
            hidden_states=hidden_states,
            key_value_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask,
        )
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        hidden_states = residual + hidden_states
        hidden_states = self.encoder_attn_layer_norm(hidden_states)

        # Fully Connected
        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
        hidden_states = self.fc2(hidden_states)
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        return hidden_states


class FlaxBartDecoderLayerCollection(nn.Module):
    """Stack of ``config.decoder_layers`` decoder layers applied in sequence."""

    config: BartConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        """Instantiate the layers, wrapped in ``nn.remat`` when ``config.gradient_checkpointing`` is set."""
        layer_module = nn.remat(FlaxBartDecoderLayer) if self.config.gradient_checkpointing else FlaxBartDecoderLayer
        self.layers = [
            layer_module(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.decoder_layers)
        ]

    def __call__(
        self,
        hidden_states,
        attention_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
        init_cache: bool = False,
    ):
        """Run every layer in order and wrap the result in ``FlaxBaseModelOutputWithPastAndCrossAttentions``."""
        # decoder layers
        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                init_cache=init_cache,
                deterministic=deterministic,
            )

        return FlaxBaseModelOutputWithPastAndCrossAttentions(last_hidden_state=hidden_states)


class FlaxBartEncoder(nn.Module):
    """Token + position embeddings, embedding LayerNorm/dropout, then the encoder layer stack."""

    config: BartConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        """Create the embedding tables, the layer stack and the embedding layer norm."""
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)

        embed_dim = self.config.d_model
        self.padding_idx = self.config.pad_token_id
        self.max_source_positions = self.config.max_position_embeddings
        self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0

        self.embed_tokens = nn.Embed(
            self.config.vocab_size,
            embed_dim,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
        )

        # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack.
        # Here the offset is 0, i.e. position ids index the table directly.
        self.offset = 0
        self.embed_positions = nn.Embed(
            self.config.max_position_embeddings + self.offset,
            embed_dim,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
        )
        self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
        self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        deterministic: bool = True,
    ):
        """
        Encode ``input_ids``.

        Args:
            input_ids: token ids; flattened to ``(-1, last_dim)`` before embedding.
            attention_mask: mask forwarded to every encoder layer.
            position_ids: indices into the position embedding table.
            deterministic: disables dropout when ``True``.

        Returns:
            A ``FlaxBaseModelOutput`` built from the layer stack's output.
        """
        input_shape = input_ids.shape
        input_ids = input_ids.reshape(-1, input_shape[-1])

        inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
        inputs_embeds = inputs_embeds.astype(self.dtype)

        embed_pos = self.embed_positions(position_ids + self.offset)
        embed_pos = embed_pos.astype(self.dtype)

        hidden_states = inputs_embeds + embed_pos
        hidden_states = self.layernorm_embedding(hidden_states)
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)

        outputs = self.layers(hidden_states, attention_mask, deterministic=deterministic)

        return FlaxBaseModelOutput(
            last_hidden_state=outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class FlaxBartDecoder(nn.Module):
    """Token + position embeddings, embedding LayerNorm/dropout, then the decoder layer stack."""

    config: BartConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        """Create the decoder embedding tables, the layer stack and the embedding layer norm."""
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)

        embed_dim = self.config.d_model
        self.padding_idx = self.config.pad_token_id
        self.max_target_positions = self.config.max_position_embeddings
        self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0

        # The decoder has its own vocabulary size and therefore its own token embedding table.
        self.embed_tokens = nn.Embed(
            self.config.decoder_vocab_size,
            embed_dim,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
        )

        # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack.
        # Here the offset is 0, i.e. position ids index the table directly.
        self.offset = 0
        self.embed_positions = nn.Embed(
            self.config.decoder_max_position_embeddings + self.offset,
            embed_dim,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
        )

        self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
        self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
    ):
        """
        Decode ``input_ids`` conditioned on ``encoder_hidden_states``.

        Args:
            input_ids: decoder token ids; flattened to ``(-1, last_dim)`` before embedding.
            attention_mask: mask forwarded to decoder self-attention.
            position_ids: indices into the decoder position embedding table.
            encoder_hidden_states: encoder output for cross-attention.
            encoder_attention_mask: optional mask for cross-attention.
            init_cache: create the self-attention caches during this call.
            deterministic: disables dropout when ``True``.

        Returns:
            A ``FlaxBaseModelOutputWithPastAndCrossAttentions`` built from the layer stack's output.
        """
        input_shape = input_ids.shape
        input_ids = input_ids.reshape(-1, input_shape[-1])

        inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
        inputs_embeds = inputs_embeds.astype(self.dtype)

        # embed positions
        positions = self.embed_positions(position_ids + self.offset)
        positions = positions.astype(self.dtype)

        hidden_states = inputs_embeds + positions
        hidden_states = self.layernorm_embedding(hidden_states)

        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)

        outputs = self.layers(
            hidden_states,
            attention_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
        )

        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )


class FlaxBartModule(nn.Module):
    """Encoder-decoder backbone: runs the encoder, then the decoder on the encoder output."""

    config: BartConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        """Instantiate the encoder and the decoder."""
        self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype)
        self.decoder = FlaxBartDecoder(self.config, dtype=self.dtype)

    def _get_encoder_module(self):
        """Return the encoder sub-module."""
        return self.encoder

    def _get_decoder_module(self):
        """Return the decoder sub-module."""
        return self.decoder

    def __call__(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        position_ids,
        decoder_position_ids,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """
        Run the full encoder-decoder forward pass.

        The encoder ``attention_mask`` is reused as the decoder's cross-attention mask.

        Returns:
            A ``FlaxSeq2SeqModelOutput`` or, when ``return_dict`` is falsy, the decoder outputs followed by
            the encoder outputs as one flat tuple.
        """
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            deterministic=deterministic,
        )

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            position_ids=decoder_position_ids,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            deterministic=deterministic,
        )

        if not return_dict:
            # Both sub-modules return output objects, which do not support `+`; concatenate their tuple forms.
            return decoder_outputs.to_tuple() + encoder_outputs.to_tuple()

        return FlaxSeq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )


class FlaxBartPreTrainedModel(FlaxPreTrainedModel):
    """Base wrapper handling parameter initialisation, cache creation and input defaults for Bart modules."""

    config_class = BartConfig
    base_model_prefix: str = "model"
    module_class: nn.Module = None

    def __init__(
        self,
        config: BartConfig,
        input_shape: Tuple[int] = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        **kwargs,
    ):
        """Instantiate ``module_class`` with ``config``/``dtype`` and hand it to ``FlaxPreTrainedModel``."""
        module = self.module_class(config=config, dtype=dtype)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, **kwargs)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        """
        Initialise the module parameters by running it once on dummy inputs of shape ``input_shape``.

        Returns:
            The ``"params"`` collection produced by ``module.init``.
        """
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        # make sure initialization pass will work for FlaxBartForSequenceClassificationModule
        # (`.at[...].set` is the supported replacement for the removed `jax.ops.index_update`)
        input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id)
        attention_mask = jnp.ones_like(input_ids)
        decoder_input_ids = input_ids
        decoder_attention_mask = jnp.ones_like(input_ids)

        batch_size, sequence_length = input_ids.shape
        position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
        decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(
            rngs,
            input_ids,
            attention_mask,
            decoder_input_ids,
            decoder_attention_mask,
            position_ids,
            decoder_position_ids,
        )["params"]

    def init_cache(self, batch_size, max_length, encoder_outputs):
        r"""
        Args:
            batch_size (:obj:`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (:obj:`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.
            encoder_outputs (:obj:`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
                ``encoder_outputs`` consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`,
                `optional`: :obj:`attentions`). :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length,
                hidden_size)`, `optional`) is a sequence of hidden-states at the output of the last layer of the
                encoder. Used in the cross-attention of the decoder.

        Returns:
            The unfrozen ``"cache"`` collection created by running the decoder once with ``init_cache=True``.
        """
        # init input variables to retrieve cache
        decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        decoder_attention_mask = jnp.ones_like(decoder_input_ids)
        decoder_position_ids = jnp.broadcast_to(
            jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
        )

        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
            decoder_module = module._get_decoder_module()
            return decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                decoder_position_ids,
                **kwargs,
            )

        # No encoder attention mask is passed here, so cross-attention runs without a mask.
        init_variables = self.module.init(
            jax.random.PRNGKey(0),
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            encoder_hidden_states=encoder_outputs[0],
            init_cache=True,
            method=_decoder_forward,  # we only need to call the decoder to init the cache
        )
        return unfreeze(init_variables["cache"])

    def encode(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Run only the encoder.

        ``attention_mask`` defaults to all ones and ``position_ids`` to ``0..sequence_length-1`` per row.

        Returns:
            The encoder module's output.

        Example::

            >>> from transformers import BartTokenizer, FlaxBartForConditionalGeneration

            >>> model = FlaxBartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
            >>> tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')

            >>> text = "My friends are cool but they eat too many carbs."
            >>> inputs = tokenizer(text, max_length=1024, return_tensors='jax')
            >>> encoder_outputs = model.encode(**inputs)
        """
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)
        if position_ids is None:
            batch_size, sequence_length = input_ids.shape
            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs):
            encode_module = module._get_encoder_module()
            return encode_module(input_ids, attention_mask, position_ids, **kwargs)

        return self.module.apply(
            {"params": params or self.params},
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            position_ids=jnp.array(position_ids, dtype="i4"),
            deterministic=not train,
            rngs=rngs,
            method=_encoder_forward,
        )

    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_input_ids: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """
        Run the wrapped module, filling in defaults for every omitted input.

        Missing masks default to all ones, missing position ids to ``0..sequence_length-1`` per row, and
        missing ``decoder_input_ids`` to ``input_ids`` shifted right by one token.

        Args:
            return_dict: accepted for API compatibility; it is not forwarded to the module.
            train: when ``True`` the module is applied with ``deterministic=False``.
            params: parameters to use instead of ``self.params``.
            dropout_rng: PRNG key for dropout; omitted from ``rngs`` when ``None``.

        Returns:
            Whatever the wrapped module returns.
        """
        # prepare encoder inputs
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)
        if position_ids is None:
            batch_size, sequence_length = input_ids.shape
            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        # prepare decoder inputs
        if decoder_input_ids is None:
            decoder_input_ids = shift_tokens_right(
                input_ids, self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id
            )
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones_like(decoder_input_ids)
        if decoder_position_ids is None:
            batch_size, sequence_length = decoder_input_ids.shape
            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # Handle any PRNG if needed
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        return self.module.apply(
            {"params": params or self.params},
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            position_ids=jnp.array(position_ids, dtype="i4"),
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            deterministic=not train,
            rngs=rngs,
        )


class FlaxBartForConditionalGenerationModule(nn.Module):
    """Bart backbone followed by a bias-free language-model head over the decoder vocabulary."""

    config: BartConfig
    dtype: jnp.dtype = jnp.float32
    # NOTE(review): `bias_init` is not read anywhere in this class (no logits bias is created in setup).
    bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros

    def setup(self):
        """Instantiate the backbone and the LM head."""
        self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
        self.lm_head = nn.Dense(
            self.config.decoder_vocab_size,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )

    def _get_encoder_module(self):
        """Return the backbone's encoder."""
        return self.model.encoder

    def _get_decoder_module(self):
        """Return the backbone's decoder."""
        return self.model.decoder

    def __call__(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        position_ids,
        decoder_position_ids,
        deterministic: bool = True,
    ):
        """
        Run the backbone and project the decoder's last hidden state to vocabulary logits.

        Returns:
            A ``FlaxSeq2SeqLMOutput`` carrying the logits and the backbone's outputs.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            position_ids=position_ids,
            decoder_position_ids=decoder_position_ids,
            deterministic=deterministic,
        )

        hidden_states = outputs[0]

        if self.config.tie_word_embeddings:
            # NOTE(review): this reads a "shared" embedding, but no module visible in this file creates a
            # parameter with that name -- verify before enabling `tie_word_embeddings`.
            shared_embedding = self.model.variables["params"]["shared"]["embedding"]
            lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
        else:
            lm_logits = self.lm_head(hidden_states)

        return FlaxSeq2SeqLMOutput(
            logits=lm_logits,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )


class FlaxBartForConditionalGeneration(FlaxBartPreTrainedModel):
    """Bart with a language-model head, plus the decode/generation helpers."""

    module_class = FlaxBartForConditionalGenerationModule
    dtype: jnp.dtype = jnp.float32

    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Run only the decoder and the LM head on pre-computed ``encoder_outputs``.

        Missing masks default to all ones. ``decoder_position_ids`` defaults to ``0..sequence_length-1`` per
        row, but must be given explicitly when ``past_key_values`` is passed.

        Returns:
            A ``FlaxCausalLMOutputWithCrossAttentions``; when a non-empty ``past_key_values`` was supplied the
            updated cache is stored under ``"past_key_values"``.

        Raises:
            ValueError: if ``past_key_values`` is given without ``decoder_position_ids``.

        Example::

            >>> from transformers import BartTokenizer, FlaxBartForConditionalGeneration

            >>> model = FlaxBartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
            >>> tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')

            >>> text = "My friends are cool but they eat too many carbs."
            >>> inputs = tokenizer(text, max_length=1024, return_tensors='jax')
            >>> encoder_outputs = model.encode(**inputs)

            >>> decoder_start_token_id = model.config.decoder_start_token_id
            >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

            >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
            >>> logits = outputs.logits
        """
        encoder_hidden_states = encoder_outputs[0]
        if encoder_attention_mask is None:
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = decoder_input_ids.shape
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones((batch_size, sequence_length))

        if decoder_position_ids is None:
            if past_key_values is not None:
                raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")

            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # if past_key_values are passed then cache is already initialized a private flag init_cache has to be
        # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
        # it can be changed by FlaxBartAttention module
        # A single flag drives both the `mutable` setting and the later unpacking, so they cannot disagree.
        use_cache = bool(past_key_values)
        if use_cache:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
            decoder_module = module._get_decoder_module()
            outputs = decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                decoder_position_ids,
                **kwargs,
            )
            hidden_states = outputs[0]

            # Same logits computation as FlaxBartForConditionalGenerationModule.__call__ (no logits bias:
            # the module defines none).
            if self.config.tie_word_embeddings:
                # NOTE(review): no module visible in this file creates a "shared" parameter -- verify.
                shared_embedding = module.model.variables["params"]["shared"]["embedding"]
                lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
            else:
                lm_logits = module.lm_head(hidden_states)

            return lm_logits, outputs

        outputs = self.module.apply(
            inputs,
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
            method=_decoder_forward,
        )

        # With a mutable collection, `apply` returns (result, mutated_variables).
        if use_cache:
            (lm_logits, decoder_outputs), past = outputs
        else:
            lm_logits, decoder_outputs = outputs

        outputs = FlaxCausalLMOutputWithCrossAttentions(
            logits=lm_logits,
            hidden_states=decoder_outputs.hidden_states,
            attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
        )

        # add updated cache to model output
        if use_cache:
            outputs["past_key_values"] = unfreeze(past["cache"])

        return outputs

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        max_length,
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        encoder_outputs=None,
        **kwargs,
    ):
        """
        Build the keyword arguments for the first ``decode`` call of a generation loop.

        Returns:
            A dict with a freshly initialised cache, the encoder outputs, the encoder attention mask, a decoder
            attention mask of length ``max_length`` and the initial decoder position ids.
        """
        # initializing the cache
        batch_size, seq_length = decoder_input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
        # But since the decoder uses a causal mask, those positions are masked anyways.
        # Thus we can create a single static attention_mask here, which is more efficient for compilation
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if decoder_attention_mask is not None:
            position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
        else:
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

        return {
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "encoder_attention_mask": attention_mask,
            "decoder_attention_mask": extended_attention_mask,
            "decoder_position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        """Carry the updated cache forward and advance the decoder position ids by one step."""
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
        return model_kwargs