| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """Layers and Modules for Knowledge-FID.""" |
|
|
import functools
from typing import Optional, Sequence, Tuple

from flax import linen as nn
import jax
import jax.numpy as jnp
from scenic.projects.knowledge_visual_language.models import constants
from t5x.examples.t5 import layers as t5_layers
from t5x.examples.t5 import network as t5_network
|
|
|
|
@jax.vmap
def batch_index_select(data, idx):
  """Gathers entries of `data` at `idx`, independently for each batch element.

  The function is wrapped in `jax.vmap`, so both arguments are mapped over
  their leading axis and `jnp.take` runs once per example along that example's
  first axis.

  Args:
    data: Batched array to gather from.
    idx: Batched integer indices into the first per-example axis of `data`.

  Returns:
    The gathered values, stacked back along the batch axis.
  """
  gathered = jnp.take(data, idx, axis=0)
  return gathered
|
|
|
|
def _mask_select(data, mask):
  """Keeps `data` where `mask` is positive and writes zeros elsewhere.

  Args:
    data: Array of values to filter.
    mask: Array compared elementwise against zero to decide what is kept.

  Returns:
    An array with the shape and dtype of `data`.
  """
  keep = mask > 0
  zeros = jnp.zeros(data.shape, dtype=data.dtype)
  return jax.lax.select(keep, data, zeros)
|
|
|
|
def l2_norm(x):
  """Returns the Euclidean (L2) norm of `x` taken over its last axis."""
  sum_of_squares = (x * x).sum(axis=-1)
  return jnp.sqrt(sum_of_squares)
|
|
|
|
def l2_normalize(x, axis=-1, eps=1e-10):
  """Scales `x` to unit L2 norm along `axis`.

  `eps` is added to the squared norm before the reciprocal square root is
  taken, so an all-zero slice maps to zeros instead of dividing by zero.

  Args:
    x: An input ndarray.
    axis: Dimension along which to normalize, e.g. `1` to separately normalize
      vectors in a batch. Passing `None` views `x` as a flattened vector when
      calculating the norm (equivalent to Frobenius norm).
    eps: Epsilon to avoid dividing by zero.

  Returns:
    An array of the same shape and dtype as `x`, L2-normalized along `axis`.
  """
  squared_norm = (x * x).sum(axis=axis, keepdims=True)
  inverse_norm = jax.lax.rsqrt(squared_norm + eps)
  normalized = x * inverse_norm
  return normalized.astype(x.dtype)
|
|
|
|
class AffineTransform(nn.Module):
  """Learned scalar gain and shift used to modulate attention scores."""

  @nn.compact
  def __call__(self, x):
    """Returns `x * sigmoid(scale) * 5 + bias` with learned scalars.

    Args:
      x: Input array to modulate.

    Returns:
      The rescaled and shifted array.
    """
    scalar_shape = (1,)
    scale = self.param(
        'scale', nn.initializers.ones, scalar_shape, jnp.float32
    )
    bias = self.param(
        'bias', nn.initializers.zeros, scalar_shape, jnp.float32
    )
    # The sigmoid keeps the gate in (0, 1); together with the constant factor
    # the multiplier applied to `x` therefore stays within (0, 5).
    gate = nn.sigmoid(scale)
    return x * gate * 5 + bias
|
|
|
|
class TransformerHead(nn.Module):
  """Encoder-layer stack that pools the first position into a unit vector.

  The module runs `num_head_layers` T5 encoder layers over the input
  sequence, layer-normalizes the embedding at sequence position 0, passes it
  through `out_head` and L2-normalizes the result along the last axis.

  Attributes:
    num_head_layers: Number of encoder layers applied by this head.
    key_dim: Not read by the code in this class.
    vocab_size: Forwarded to the T5 config.
    emb_dim: Forwarded to the T5 config.
    num_heads: Number of attention heads; also sizes the relative position
      biases.
    num_encoder_layers: Forwarded to the T5 config.
    num_decoder_layers: Forwarded to the T5 config.
    head_dim: Forwarded to the T5 config.
    mlp_dim: Forwarded to the T5 config.
    dropout_rate: Forwarded to the T5 config.
    out_head: Module applied to the pooled, layer-normalized embedding.
    dtype: Computation dtype used by the layers.
    mlp_activations: Forwarded to the T5 config.
    logits_via_embedding: Forwarded to the T5 config.
  """

  num_head_layers: int
  key_dim: int
  vocab_size: int
  emb_dim: int
  num_heads: int
  num_encoder_layers: int
  num_decoder_layers: int
  head_dim: int
  mlp_dim: int
  dropout_rate: float
  out_head: nn.Module
  dtype: str = 'bfloat16'
  mlp_activations: Sequence[str] = ('gelu', 'linear')
  logits_via_embedding: bool = False

  def setup(self):
    """Builds the T5 config shared by every layer of the head."""
    config_kwargs = dict(
        vocab_size=self.vocab_size,
        emb_dim=self.emb_dim,
        num_heads=self.num_heads,
        num_encoder_layers=self.num_encoder_layers,
        num_decoder_layers=self.num_decoder_layers,
        head_dim=self.head_dim,
        mlp_dim=self.mlp_dim,
        dropout_rate=self.dropout_rate,
        dtype=self.dtype,
        mlp_activations=self.mlp_activations,
        logits_via_embedding=self.logits_via_embedding,
    )
    self.t5_config = t5_network.T5Config(**config_kwargs)

  @nn.compact
  def __call__(self, encoded_emb, encoder_mask=None, use_dropout=True):
    """Transforms an encoded sequence into one normalized vector per example.

    Args:
      encoded_emb: Rank-3 array of encoded embeddings (asserted below).
      encoder_mask: Optional per-position mask; when given it is expanded into
        an attention mask for the encoder layers.
      use_dropout: Whether dropout is active inside the encoder layers.

    Returns:
      The L2-normalized output of `out_head` for the first sequence position.
    """
    cfg = self.t5_config
    assert encoded_emb.ndim == 3

    attention_mask = None
    if encoder_mask is not None:
      attention_mask = t5_layers.make_attention_mask(
          encoder_mask, encoder_mask, dtype=cfg.dtype
      )

    # One relative-position-bias module is shared by all layers of the head.
    position_bias = t5_layers.RelativePositionBiases(
        num_buckets=32,
        max_distance=128,
        num_heads=self.num_heads,
        dtype=self.dtype,
        embedding_init=nn.initializers.variance_scaling(
            1.0, 'fan_avg', 'uniform'
        ),
    )

    deterministic = not use_dropout
    hidden = encoded_emb
    for _ in range(self.num_head_layers):
      encoder_layer = t5_network.EncoderLayer(
          config=cfg, relative_embedding=position_bias
      )
      hidden = encoder_layer(
          hidden, attention_mask, deterministic=deterministic
      )

    # Only the embedding at sequence position 0 feeds the output head.
    pooled = t5_layers.LayerNorm(dtype=cfg.dtype)(hidden[:, 0, :])
    return l2_normalize(self.out_head(pooled), axis=-1)
|
|
|
|
class LowerT5Encoder(nn.Module):
  """Lower (text-only) part of a T5 encoder, run before multi-modal fusion.

  The module embeds the input tokens and applies the first
  `num_encoder_layers - num_fusion_layers` encoder layers. The embedding layer
  can be created outside the module and provided as `shared_embedding` so that
  it is shared with other parts of the network. If `shared_embedding` is not
  provided, an embedding layer is created within the module.

  Attributes:
    vocab_size: Size of the vocabulary.
    emb_dim: Size of the embeddings.
    num_heads: Number of attention heads.
    num_encoder_layers: Number of encoder layers.
    num_decoder_layers: Number of decoder layers.
    num_fusion_layers: Number of trailing encoder layers that are NOT run
      here; they are subtracted from `num_encoder_layers`.
    head_dim: Size of the embeddings in each head.
    mlp_dim: Size of the MLP output embeddings.
    dropout_rate: Dropout rate.
    dtype: Data type.
    mlp_activations: Sequence of activations in MLP.
    logits_via_embedding: Use the embedding weights for computing logits.
    shared_embedding: Optional. Embedding layer that is shared outside this
      module. If not given, a non-shared embedding layer will be created within
      the module.
  """

  vocab_size: int
  emb_dim: int
  num_heads: int
  num_encoder_layers: int
  num_decoder_layers: int
  num_fusion_layers: int
  head_dim: int
  mlp_dim: int
  dropout_rate: float
  dtype: str = 'bfloat16'
  mlp_activations: Sequence[str] = ('gelu', 'linear')
  logits_via_embedding: bool = False
  shared_embedding: Optional[nn.Module] = None

  def setup(self):
    """Builds the T5 config and, if needed, a private embedding layer."""
    self.t5_config = t5_network.T5Config(
        vocab_size=self.vocab_size,
        emb_dim=self.emb_dim,
        num_heads=self.num_heads,
        num_encoder_layers=self.num_encoder_layers,
        num_decoder_layers=self.num_decoder_layers,
        head_dim=self.head_dim,
        mlp_dim=self.mlp_dim,
        dropout_rate=self.dropout_rate,
        dtype=self.dtype,
        mlp_activations=self.mlp_activations,
        logits_via_embedding=self.logits_via_embedding,
    )
    # Flax does not allow a dataclass field such as `shared_embedding` to be
    # re-assigned inside `setup()` (SetAttributeInModuleSetupError), so the
    # fallback embedding is stored under its own attribute name.
    if self.shared_embedding is None:
      self.token_embedder = t5_layers.Embed(
          num_embeddings=self.vocab_size,
          features=self.emb_dim,
          dtype=self.dtype,
          attend_dtype=jnp.float32,
          embedding_init=nn.initializers.normal(stddev=1.0),
          one_hot=True,
      )

  @nn.compact
  def __call__(
      self,
      encoder_input_tokens,
      encoder_segment_ids=None,
      use_dropout=True,
      frozen_base=True,
  ):
    """Encodes the text tokens only.

    Args:
      encoder_input_tokens: Rank-2 array of input text tokens; positions with
        a value greater than 0 are treated as valid.
      encoder_segment_ids: Segment IDs in packing mode. When given, attention
        is additionally restricted to positions with equal segment ID.
      use_dropout: Whether to use dropout during training.
      frozen_base: Whether to freeze the lower text-encoder layers. When set,
        gradients are stopped at the output of layer index
        `int(0.8 * n_layer) - 1`, where `n_layer` is the number of layers run
        by this module.

    Returns:
      A tuple `(x, encoder_mask)`: the sequence of token embeddings produced
      by the lower encoder layers, and the boolean mask of valid positions.
    """
    cfg = self.t5_config
    assert encoder_input_tokens.ndim == 2

    encoder_mask = encoder_input_tokens > 0
    mask_matrix = t5_layers.make_attention_mask(
        encoder_mask, encoder_mask, dtype=cfg.dtype
    )

    if encoder_segment_ids is not None:
      mask_matrix = t5_layers.combine_masks(
          mask_matrix,
          t5_layers.make_attention_mask(
              encoder_segment_ids,
              encoder_segment_ids,
              jnp.equal,
              dtype=cfg.dtype,
          ),
      )

    rel_emb = t5_layers.RelativePositionBiases(
        num_buckets=32,
        max_distance=128,
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        embedding_init=nn.initializers.variance_scaling(
            1.0, 'fan_avg', 'uniform'
        ),
    )

    # Prefer the externally shared embedding; otherwise use the one created
    # in `setup()`.
    embedder = self.shared_embedding
    if embedder is None:
      embedder = self.token_embedder

    x = embedder(encoder_input_tokens.astype('int32'))
    x = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(
        x, deterministic=not use_dropout
    )
    x = x.astype(cfg.dtype)

    n_layer = cfg.num_encoder_layers - self.num_fusion_layers
    frozen_layer_id = int(n_layer * 0.8) - 1
    for lyr in range(n_layer):
      x = t5_network.EncoderLayer(config=cfg, relative_embedding=rel_emb)(
          x, mask_matrix, deterministic=not use_dropout
      )
      # Cutting the gradient here leaves the embedding and every layer up to
      # and including `frozen_layer_id` without gradient from later layers.
      if frozen_base and lyr == frozen_layer_id:
        x = jax.lax.stop_gradient(x)
    return x, encoder_mask
|
|
|
|
class FusedT5Encoder(nn.Module):
  """Upper part of a T5 encoder which fuses multi-modal input.

  The module runs `num_fusion_layers` `FusionEncoderLayer`s over the
  concatenation of the encoded text embeddings and the embeddings of other
  modalities, and collects the attention weights of every layer.

  Attributes:
    vocab_size: Size of the vocabulary.
    emb_dim: Size of the embeddings.
    num_heads: Number of attention heads.
    num_encoder_layers: Number of encoder layers.
    num_decoder_layers: Number of decoder layers.
    num_fusion_layers: Number of fusion layers applied by this module.
    head_dim: Size of the embeddings in each head.
    mlp_dim: Size of the MLP output embeddings.
    dropout_rate: Dropout rate.
    dtype: Data type.
    mlp_activations: Sequence of activations in MLP.
    logits_via_embedding: Use the embedding weights for computing logits.
  """

  vocab_size: int
  emb_dim: int
  num_heads: int
  num_encoder_layers: int
  num_decoder_layers: int
  num_fusion_layers: int
  head_dim: int
  mlp_dim: int
  dropout_rate: float
  dtype: str = 'bfloat16'
  mlp_activations: Sequence[str] = ('gelu', 'linear')
  logits_via_embedding: bool = False

  def setup(self):
    """Builds the T5 config shared by every fusion layer."""
    self.t5_config = t5_network.T5Config(
        vocab_size=self.vocab_size,
        emb_dim=self.emb_dim,
        num_heads=self.num_heads,
        num_encoder_layers=self.num_encoder_layers,
        num_decoder_layers=self.num_decoder_layers,
        head_dim=self.head_dim,
        mlp_dim=self.mlp_dim,
        dropout_rate=self.dropout_rate,
        dtype=self.dtype,
        mlp_activations=self.mlp_activations,
        logits_via_embedding=self.logits_via_embedding,
    )

  @nn.compact
  def __call__(
      self,
      fused_input_embs,
      encoder_input_embs=None,
      encoder_mask=None,
      fused_mask=None,
      att_mask=None,
      use_dropout=True,
      output=False,
  ):
    """Fuses text and image embeddings.

    Encodes both the encoded text embedding (`encoder_input_embs`) and the
    encoded embedding of other modalities (`fused_input_embs`) together using
    self-attentive Transformer layers.

    Args:
      fused_input_embs: Pre-encoded embeddings of other modalities.
      encoder_input_embs: Encoded text embedding sequence. When given it is
        concatenated in front of `fused_input_embs` along axis 1.
      encoder_mask: Mask for the encoding (text) part.
      fused_mask: Mask for the fusion part. If `encoder_mask` is given and
        this is None, the fusion part is treated as fully valid.
      att_mask: Pre-computed attention product multiplied into each layer's
        normalized input and into the final normalized output.
      use_dropout: Whether to use dropout.
      output: Whether this is the output layer, i.e. whether the final
        layer norm is applied.

    Returns:
      A tuple `(x, fused_mask, attn_weights_all_layers)`: the sequence of
      token embeddings after fusion, the combined mask (None if no mask was
      supplied), and a list with the attention weights of each fusion layer.
    """
    cfg = self.t5_config
    if encoder_input_embs is not None:
      x = jnp.concatenate([encoder_input_embs, fused_input_embs], axis=1)
    else:
      x = fused_input_embs

    rel_emb = t5_layers.RelativePositionBiases(
        num_buckets=32,
        max_distance=128,
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        embedding_init=nn.initializers.variance_scaling(
            1.0, 'fan_avg', 'uniform'
        ),
    )

    if encoder_mask is not None:
      if fused_mask is None:
        # No mask for the fused part: extend the text mask with ones so every
        # fused position can be attended to.
        pad_width = fused_input_embs.shape[1]
        fused_mask = jnp.pad(
            array=encoder_mask,
            pad_width=((0, 0), (0, pad_width)),
            mode='constant',
            constant_values=1.0,
        )
      else:
        fused_mask = jnp.concatenate([encoder_mask, fused_mask], axis=1)

    # Both masks default to None; in that case no attention mask is built
    # (the attention module treats `mask=None` as "attend everywhere")
    # instead of passing None into `make_attention_mask`.
    mask_matrix = None
    if fused_mask is not None:
      mask_matrix = t5_layers.make_attention_mask(
          fused_mask, fused_mask, dtype=cfg.dtype
      )

    attn_weights_all_layers = []
    for _ in range(self.num_fusion_layers):
      x, attn_weights = FusionEncoderLayer(
          config=cfg, relative_embedding=rel_emb
      )(
          x,
          encoder_mask=mask_matrix,
          att_mask=att_mask,
          deterministic=not use_dropout,
      )
      attn_weights_all_layers.append(attn_weights)

    # NOTE(review): the indentation of the original source was lost; the
    # final norm, `att_mask` scaling and dropout are assumed to all be gated
    # by `output` -- confirm against the upstream file.
    if output:
      x = t5_layers.LayerNorm(dtype=cfg.dtype)(x)
      if att_mask is not None:
        x = x * att_mask
      x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=not use_dropout)
    return x, fused_mask, attn_weights_all_layers
|
|
|
|
class FusionEncoderLayer(nn.Module):
  """Transformer encoder layer that also returns its attention weights.

  Layer norm is applied before the attention block and before the MLP block,
  and each block's output is added back to that block's input.

  Attributes:
    config: T5 configuration supplying dtype, head sizes, MLP settings and
      the dropout rate.
    relative_embedding: Module called with the sequence length (twice) and
      `True`; its result is passed to the attention as bias.
  """

  config: t5_network.T5Config
  relative_embedding: nn.Module

  @nn.compact
  def __call__(
      self, inputs, att_mask=None, encoder_mask=None, deterministic=False
  ):
    """Applies self-attention and the MLP block to `inputs`.

    Args:
      inputs: Rank-3 input array (asserted below).
      att_mask: Optional factor multiplied into the normalized input before
        attention.
      encoder_mask: Optional attention mask handed to the attention module.
      deterministic: Disables dropout if set to True.

    Returns:
      A tuple `(outputs, attn_weights)`.
    """
    cfg = self.config
    seq_len = inputs.shape[-2]
    position_bias = self.relative_embedding(seq_len, seq_len, True)

    assert inputs.ndim == 3

    # Attention block.
    normed = t5_layers.LayerNorm(dtype=cfg.dtype)(inputs)
    if att_mask is not None:
      normed = normed * att_mask
    attention = MultiHeadDotProductAttention(
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        head_dim=cfg.head_dim,
        dropout_rate=cfg.dropout_rate,
        float32_logits=cfg.float32_attention_logits,
    )
    attended, attn_weights = attention(
        normed, normed, encoder_mask, position_bias, deterministic=deterministic
    )
    attended = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(
        attended, deterministic=deterministic
    )
    hidden = attended + inputs

    # MLP block.
    mlp_in = t5_layers.LayerNorm(dtype=cfg.dtype)(hidden)
    mlp_out = t5_layers.MlpBlock(
        intermediate_dim=cfg.mlp_dim,
        activations=cfg.mlp_activations,
        intermediate_dropout_rate=cfg.dropout_rate,
        dtype=cfg.dtype,
    )(mlp_in, deterministic=deterministic)
    mlp_out = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(
        mlp_out, deterministic=deterministic
    )
    outputs = mlp_out + hidden

    return outputs, attn_weights
|
|
|
|
class PerceiverEncoder(nn.Module):
  """Reimplementation of Perceiver.

  Perceiver: General Perception with Iterative Attention
  (https://arxiv.org/abs/2103.03206)

  A learned array of `perceiver_output_dim` latent vectors is repeated for
  every example and refined by `num_fusion_layers` T5 decoder layers that
  receive the encoded input sequence.

  Attributes:
    perceiver_output_dim: Number of learned latent vectors.
    vocab_size: Forwarded to the T5 config.
    emb_dim: Size of each latent vector; also forwarded to the T5 config.
    num_heads: Forwarded to the T5 config.
    num_encoder_layers: Forwarded to the T5 config.
    num_decoder_layers: Forwarded to the T5 config.
    num_fusion_layers: Number of decoder layers applied to the latents.
    head_dim: Forwarded to the T5 config.
    mlp_dim: Forwarded to the T5 config.
    dropout_rate: Forwarded to the T5 config.
    dtype: Computation dtype.
    mlp_activations: Forwarded to the T5 config.
    logits_via_embedding: Forwarded to the T5 config.
  """

  perceiver_output_dim: int
  vocab_size: int
  emb_dim: int
  num_heads: int
  num_encoder_layers: int
  num_decoder_layers: int
  num_fusion_layers: int
  head_dim: int
  mlp_dim: int
  dropout_rate: float
  dtype: str = 'bfloat16'
  mlp_activations: Sequence[str] = ('gelu', 'linear')
  logits_via_embedding: bool = False

  def setup(self):
    """Creates the T5 config, the latent parameter and the triangle selector."""
    config_kwargs = dict(
        vocab_size=self.vocab_size,
        emb_dim=self.emb_dim,
        num_heads=self.num_heads,
        num_encoder_layers=self.num_encoder_layers,
        num_decoder_layers=self.num_decoder_layers,
        head_dim=self.head_dim,
        mlp_dim=self.mlp_dim,
        dropout_rate=self.dropout_rate,
        dtype=self.dtype,
        mlp_activations=self.mlp_activations,
        logits_via_embedding=self.logits_via_embedding,
    )
    self.t5_config = t5_network.T5Config(**config_kwargs)

    latent_shape = (1, self.perceiver_output_dim, self.emb_dim)
    self.perceive_embedding = self.param(
        'perceive_embedding',
        nn.initializers.normal(stddev=1.0),
        latent_shape,
        jnp.float32,
    )

    # Entry [i, j] is True when j < i, i.e. the strict lower triangle, so each
    # unordered pair of distinct latents is kept exactly once.
    positions = jnp.arange(self.perceiver_output_dim)
    lower_triangle = positions < positions.reshape([-1, 1])
    self.batch_triangle_select = jax.vmap(
        functools.partial(_mask_select, mask=lower_triangle)
    )

  def linear_disentangle(self, y):
    """Returns a scalar measuring correlation between the latent vectors.

    The latents are centered over the second-to-last axis and L2-normalized;
    the squared pairwise dot products are kept for the strict lower triangle
    (zero elsewhere) and averaged over all entries.

    Args:
      y: Array indexed as [batch, latent, feature] by the einsum below.

    Returns:
      The mean of the masked squared similarity matrix.
    """
    centered = y - y.mean(axis=-2, keepdims=True)
    unit = l2_normalize(centered)
    similarity = jnp.einsum('bqd,btd->bqt', unit, unit)
    masked = self.batch_triangle_select(jnp.square(similarity))
    return jnp.mean(masked)

  @nn.compact
  def __call__(self, encoded, encoded_mask, use_dropout=False):
    """Summarizes `encoded` into a fixed number of latent vectors.

    Args:
      encoded: Encoded input sequence handed to the decoder layers.
      encoded_mask: Mask over the positions of `encoded`.
      use_dropout: Whether dropout is active.

    Returns:
      A tuple `(latents * 4, mask, disentangle)`: the scaled latents, an
      all-True mask of shape `[batch, perceiver_output_dim]`, and the value of
      `linear_disentangle` on the unscaled latents.
    """
    cfg = self.t5_config
    deterministic = not use_dropout
    position_bias = t5_layers.RelativePositionBiases(
        num_buckets=32,
        max_distance=128,
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        embedding_init=nn.initializers.variance_scaling(
            1.0, 'fan_avg', 'uniform'
        ),
    )

    encoded = t5_layers.LayerNorm(dtype=cfg.dtype)(encoded)
    batch_size = encoded.shape[0]

    # Give every example its own copy of the learned latents.
    latents = jnp.asarray(self.perceive_embedding, dtype=cfg.dtype)
    latents = jnp.repeat(latents, batch_size, axis=0)
    latents = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(
        latents, deterministic=deterministic
    )
    latents = latents.astype(cfg.dtype)

    # Every latent is valid; which encoded positions may be attended to is
    # governed by `encoded_mask`.
    mask = jnp.ones([batch_size, self.perceiver_output_dim]).astype(bool)
    encoder_decoder_mask = t5_layers.make_attention_mask(
        mask, encoded_mask, dtype=self.dtype
    )

    for _ in range(self.num_fusion_layers):
      decoder_layer = t5_network.DecoderLayer(
          config=cfg, relative_embedding=position_bias
      )
      latents = decoder_layer(
          latents,
          encoded,
          deterministic=deterministic,
          encoder_decoder_mask=encoder_decoder_mask,
          decode=False,
      )

    return latents * 4, mask, self.linear_disentangle(latents)
|
|
|
|
def dot_product_attention(
    query: constants.JTensor,
    key: constants.JTensor,
    value: constants.JTensor,
    bias: Optional[constants.JTensor] = None,
    dropout_rng: Optional[constants.JTensor] = None,
    dropout_rate: float = 0.0,
    deterministic: bool = False,
    dtype: constants.DType = jnp.float32,
    float32_logits: bool = False,
):
  """Computes dot-product attention given query, key, and value.

  This is the core function for applying attention based on
  https://arxiv.org/abs/1706.03762. It calculates the attention weights given
  query and key and combines the values using the attention weights. No
  `1/sqrt(depth)` scaling is applied to the logits inside this function.

  Args:
    query: queries for calculating attention with shape of `[batch, q_length,
      num_heads, qk_depth_per_head]`.
    key: keys for calculating attention with shape of `[batch, kv_length,
      num_heads, qk_depth_per_head]`.
    value: values to be used in attention with shape of `[batch, kv_length,
      num_heads, v_depth_per_head]`.
    bias: bias for the attention weights. This should be broadcastable to the
      shape `[batch, num_heads, q_length, kv_length]` This can be used for
      incorporating causal masks, padding masks, proximity bias, etc.
    dropout_rng: JAX PRNGKey: to be used for dropout
    dropout_rate: dropout rate
    deterministic: bool, deterministic or not (to apply dropout)
    dtype: the dtype of the computation (default: float32)
    float32_logits: bool, if True then compute logits in float32 to avoid
      numerical issues with bfloat16.

  Returns:
    A tuple `(output, attn_weights)`: `output` has shape `[batch, q_length,
    num_heads, v_depth_per_head]` and `attn_weights` has shape `[batch,
    num_heads, q_length, kv_length]` (after dropout, when dropout is applied).
  """
  assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
  assert (
      query.shape[:-3] == key.shape[:-3] == value.shape[:-3]
  ), 'q, k, v batch dims must match.'
  assert (
      query.shape[-2] == key.shape[-2] == value.shape[-2]
  ), 'q, k, v num_heads must match.'
  assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.'
  assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'

  # Optionally upcast so the logits are accumulated in float32.
  if float32_logits:
    query, key = query.astype(jnp.float32), key.astype(jnp.float32)

  # Logits laid out as [batch, num_heads, q_length, kv_length].
  logits = jnp.einsum('bqhd,bkhd->bhqk', query, key)
  if bias is not None:
    logits = logits + bias.astype(logits.dtype)

  # Normalize over the key axis, then move to the computation dtype.
  attn_weights = jax.nn.softmax(logits).astype(dtype)

  dropout_active = (not deterministic) and dropout_rate > 0.0
  if dropout_active:
    keep_prob = 1.0 - dropout_rate
    # The keep-mask is sampled with a singleton query axis and broadcast, so
    # all query positions of a head share the same dropped keys.
    mask_shape = list(attn_weights.shape)
    mask_shape[-2] = 1
    keep = jax.random.bernoulli(dropout_rng, keep_prob, mask_shape)
    keep = jnp.broadcast_to(keep, attn_weights.shape)
    rescale = keep.astype(attn_weights.dtype) / jnp.asarray(
        keep_prob, dtype=dtype
    )
    attn_weights = attn_weights * rescale

  # Weighted sum of the values for every query position.
  attended = jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value)
  return attended, attn_weights
|
|
|
|
class MultiHeadDotProductAttention(nn.Module):
  """Multi-head dot-product attention that also returns attention weights.

  Attributes:
    num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
      should be divisible by the number of heads.
    head_dim: dimension of each head.
    dtype: the dtype of the computation.
    dropout_rate: dropout rate
    kernel_init: initializer for the kernel of the Dense layers.
    float32_logits: bool, if True then compute logits in float32 to avoid
      numerical issues with bfloat16.
  """

  num_heads: int
  head_dim: int
  dtype: constants.DType = jnp.float32
  dropout_rate: float = 0.0
  kernel_init: constants.Initializer = nn.initializers.variance_scaling(
      1.0, 'fan_in', 'normal'
  )
  float32_logits: bool = False

  @nn.compact
  def __call__(
      self,
      inputs_q: constants.JTensor,
      inputs_kv: constants.JTensor,
      mask: Optional[constants.JTensor] = None,
      bias: Optional[constants.JTensor] = None,
      *,
      decode: bool = False,
      deterministic: bool = False,
  ) -> Tuple[constants.JTensor, constants.JTensor]:
    """Applies multi-head dot product attention on the input data.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention and project the results to an output vector.

    There are two modes: decoding and non-decoding (e.g., training). The mode is
    determined by `decode` argument. For decoding, this method is called twice,
    first to initialize the cache and then for an actual decoding process. The
    two calls are differentiated by the presence of 'cached_key' in the variable
    dict. In the cache initialization stage, the cache variables are initialized
    as zeros and will be filled in the subsequent decoding process.

    In the cache initialization call, `inputs_q` has a shape [batch, length,
    q_features] and `inputs_kv`: [batch, length, kv_features]. During the
    incremental decoding stage, query, key and value all have the shape [batch,
    1, qkv_features] corresponding to a single step.

    Args:
      inputs_q: input queries of shape `[batch, q_length, q_features]`.
      inputs_kv: key/values of shape `[batch, kv_length, kv_features]`.
      mask: attention mask of shape `[batch, num_heads, q_length, kv_length]`.
      bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`.
      decode: Whether to prepare and use an autoregressive cache.
      deterministic: Disables dropout if set to True.

    Returns:
      A tuple `(out, attn_weights)`: `out` is the projected attention output
      of shape `[batch, length, q_features]`, and `attn_weights` are the
      attention weights returned by `dot_product_attention`.

    Raises:
      ValueError: If, during cached decoding, the query shape is not
        `(batch, 1, num_heads, head_dim)`.
    """
    projection = functools.partial(
        t5_layers.DenseGeneral,
        axis=-1,
        features=(self.num_heads, self.head_dim),
        kernel_axes=('embed', 'joined_kv'),
        dtype=self.dtype,
    )

    # The attention logits are not rescaled by 1/sqrt(head_dim) later on;
    # the scaling is folded into the initializer of the query projection.
    depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
    query_init = lambda *args: self.kernel_init(*args) / depth_scaling

    # Project the inputs to multi-headed q/k/v:
    # [batch, length, num_heads, head_dim].
    query = projection(kernel_init=query_init)(inputs_q)
    key = projection(kernel_init=self.kernel_init)(inputs_kv)
    value = projection(kernel_init=self.kernel_init)(inputs_kv)

    query = t5_layers.with_sharding_constraint(
        query, ('batch', 'length', 'heads', 'kv')
    )
    key = t5_layers.with_sharding_constraint(
        key, ('batch', 'length', 'heads', 'kv')
    )
    value = t5_layers.with_sharding_constraint(
        value, ('batch', 'length', 'heads', 'kv')
    )

    if decode:
      # Detect whether this call initializes the cache or decodes one step.
      is_initialized = self.has_variable('cache', 'cached_key')
      # The cache stores keys/values with the length axis moved last:
      # [batch, num_heads, head_dim, length].
      swap_dims = lambda x: x[:-3] + tuple(x[i] for i in [-2, -1, -3])
      cached_key = self.variable(
          'cache', 'cached_key', jnp.zeros, swap_dims(key.shape), key.dtype
      )
      cached_value = self.variable(
          'cache',
          'cached_value',
          jnp.zeros,
          swap_dims(value.shape),
          value.dtype,
      )
      cache_index = self.variable(
          'cache', 'cache_index', lambda: jnp.array(0, dtype=jnp.int32)
      )
      if is_initialized:
        batch, num_heads, head_dim, length = cached_key.value.shape
        # During decoding the query must cover exactly one position.
        expected_shape = (batch, 1, num_heads, head_dim)
        if expected_shape != query.shape:
          raise ValueError(
              'Autoregressive cache shape error, '
              'expected query shape %s instead got %s.'
              % (expected_shape, query.shape)
          )

        # Write the new key/value into the cache slot `cur_index` by adding a
        # one-hot-masked copy of it to the cached tensors.
        cur_index = cache_index.value
        one_hot_indices = jax.nn.one_hot(cur_index, length, dtype=key.dtype)
        # Move the (singleton) length axis last so it broadcasts against the
        # one-hot vector over the cache length.
        one_token_key = jnp.moveaxis(key, -3, -1)
        one_token_value = jnp.moveaxis(value, -3, -1)
        key = cached_key.value + one_token_key * one_hot_indices
        value = cached_value.value + one_token_value * one_hot_indices
        cached_key.value = key
        cached_value.value = value
        cache_index.value = cache_index.value + 1
        # Restore the [batch, length, num_heads, head_dim] layout.
        key = jnp.moveaxis(key, -1, -3)
        value = jnp.moveaxis(value, -1, -3)

        # Only cache positions up to and including `cur_index` have been
        # written, so restrict attention to those.
        mask = t5_layers.combine_masks(
            mask,
            jnp.broadcast_to(
                jnp.arange(length) <= cur_index,
                (batch, 1, 1, length),
            ),
        )

        # The bias covers all query positions; keep only the slice for the
        # position being decoded.
        if bias is not None:
          bias = t5_layers.dynamic_vector_slice_in_dim(
              jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2
          )

    # Convert the boolean-like mask into an additive bias: 0 where attention
    # is allowed and a large negative value where it is not.
    if mask is not None:
      attention_bias = jax.lax.select(
          mask > 0,
          jnp.full(mask.shape, 0.0).astype(self.dtype),
          jnp.full(mask.shape, -1e10).astype(self.dtype),
      )
    else:
      attention_bias = None

    # Merge the mask-derived bias with the caller-provided bias.
    if bias is not None:
      attention_bias = t5_layers.combine_biases(attention_bias, bias)

    dropout_rng = None
    if not deterministic and self.dropout_rate > 0.0:
      dropout_rng = self.make_rng('dropout')

    x, attn_weights = dot_product_attention(
        query,
        key,
        value,
        bias=attention_bias,
        dropout_rng=dropout_rng,
        dropout_rate=self.dropout_rate,
        deterministic=deterministic,
        dtype=self.dtype,
        float32_logits=self.float32_logits,
    )

    # Merge the head axes back into the feature size of `inputs_q`.
    out = t5_layers.DenseGeneral(
        features=inputs_q.shape[-1],
        axis=(-2, -1),
        kernel_init=self.kernel_init,
        kernel_axes=('joined_kv', 'embed'),
        dtype=self.dtype,
    )(x)
    return out, attn_weights
|
|