# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""T5.1.1 Transformer model."""

from typing import Any, Sequence

from flax import linen as nn
from flax import struct
from flax.linen import partitioning as nn_partitioning
import jax
import jax.numpy as jnp
from t5x.examples.scalable_t5 import layers

with_sharding_constraint = nn_partitioning.with_sharding_constraint
scan_with_axes = nn_partitioning.scan_with_axes
remat = nn_partitioning.remat
ScanIn = nn_partitioning.ScanIn


@struct.dataclass
class T5Config:
  """Global hyperparameters used to minimize obnoxious kwarg plumbing."""
  vocab_size: int
  # Activation dtypes.
  dtype: Any = jnp.float32
  emb_dim: int = 512
  num_heads: int = 8
  num_encoder_layers: int = 6
  num_decoder_layers: int = 6
  head_dim: int = 64
  mlp_dim: int = 2048
  # Activation functions are retrieved from Flax.
  mlp_activations: Sequence[str] = ('relu',)
  dropout_rate: float = 0.1
  # If `True`, the embedding weights are used in the decoder output layer.
  logits_via_embedding: bool = False
  # minimal, full, or none
  remat_policy: str = 'none'
  scan_layers: bool = True
  param_scan_axis: int = 1


class EncoderLayer(nn.Module):
  """Transformer encoder layer."""
  config: T5Config

  @nn.compact
  def __call__(self, inputs, encoder_mask=None, deterministic=False):
    cfg = self.config

    # Relative position embedding as attention biases.
    encoder_bias = layers.RelativePositionBiases(
        num_buckets=32,
        max_distance=128,
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        embedding_init=nn.initializers.variance_scaling(
            1.0, 'fan_avg', 'uniform'),
        name='relative_posemb')(inputs.shape[-2], inputs.shape[-2], True)

    # Attention block.
    assert inputs.ndim == 3
    inputs = with_sharding_constraint(inputs, ('batch', 'length', 'embed'))
    x = layers.LayerNorm(
        dtype=cfg.dtype, name='pre_attention_layer_norm')(
            inputs)
    x = with_sharding_constraint(x, ('batch', 'length', 'embed'))
    # [batch, length, emb_dim] -> [batch, length, emb_dim]
    x = layers.MultiHeadDotProductAttention(
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        head_dim=cfg.head_dim,
        dropout_rate=cfg.dropout_rate,
        name='attention')(
            x, x, encoder_mask, encoder_bias, deterministic=deterministic)
    x = nn.Dropout(
        rate=cfg.dropout_rate, broadcast_dims=(-2,))(
            x, deterministic=deterministic)
    x = x + inputs
    x = with_sharding_constraint(x, ('batch', 'length', 'embed'))

    # MLP block.
    y = layers.LayerNorm(dtype=cfg.dtype, name='pre_mlp_layer_norm')(x)
    y = with_sharding_constraint(y, ('batch', 'length', 'embed'))
    # [batch, length, emb_dim] -> [batch, length, emb_dim]
    y = layers.MlpBlock(
        intermediate_dim=cfg.mlp_dim,
        activations=cfg.mlp_activations,
        intermediate_dropout_rate=cfg.dropout_rate,
        dtype=cfg.dtype,
        name='mlp',
    )(y, deterministic=deterministic)
    y = nn.Dropout(
        rate=cfg.dropout_rate, broadcast_dims=(-2,))(
            y, deterministic=deterministic)
    y = y + x
    y = with_sharding_constraint(y, ('batch', 'length', 'embed'))

    if cfg.scan_layers:
      return y, None
    else:
      return y


class DecoderLayer(nn.Module):
  """Transformer decoder layer that attends to the encoder."""
  config: T5Config

  @nn.compact
  def __call__(self,
               inputs,
               encoded,
               decoder_mask=None,
               encoder_decoder_mask=None,
               deterministic=False,
               decode=False,
               max_decode_length=None):
    cfg = self.config

    # Relative position embedding as attention biases.
    l = max_decode_length if decode and max_decode_length else inputs.shape[-2]
    decoder_bias = layers.RelativePositionBiases(
        num_buckets=32,
        max_distance=128,
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        embedding_init=nn.initializers.variance_scaling(
            1.0, 'fan_avg', 'uniform'),
        name='relative_posemb')(l, l, False)

    inputs = with_sharding_constraint(inputs, ('batch', 'length', 'embed'))

    # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]
    x = layers.LayerNorm(
        dtype=cfg.dtype, name='pre_self_attention_layer_norm')(
            inputs)
    x = with_sharding_constraint(x, ('batch', 'length', 'embed'))

    # Self-attention block
    x = layers.MultiHeadDotProductAttention(
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        head_dim=cfg.head_dim,
        dropout_rate=cfg.dropout_rate,
        name='self_attention')(
            x,
            x,
            decoder_mask,
            decoder_bias,
            deterministic=deterministic,
            decode=decode)
    x = nn.Dropout(
        rate=cfg.dropout_rate, broadcast_dims=(-2,))(
            x, deterministic=deterministic)
    x = x + inputs
    x = with_sharding_constraint(x, ('batch', 'length', 'embed'))

    # Encoder-Decoder block.
    y = layers.LayerNorm(
        dtype=cfg.dtype, name='pre_cross_attention_layer_norm')(
            x)
    y = with_sharding_constraint(y, ('batch', 'length', 'embed'))
    y = layers.MultiHeadDotProductAttention(
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        head_dim=cfg.head_dim,
        dropout_rate=cfg.dropout_rate,
        name='encoder_decoder_attention')(
            y, encoded, encoder_decoder_mask, deterministic=deterministic)
    y = nn.Dropout(
        rate=cfg.dropout_rate, broadcast_dims=(-2,))(
            y, deterministic=deterministic)
    y = y + x
    y = with_sharding_constraint(y, ('batch', 'length', 'embed'))

    # MLP block.
    z = layers.LayerNorm(dtype=cfg.dtype, name='pre_mlp_layer_norm')(y)
    z = with_sharding_constraint(z, ('batch', 'length', 'embed'))
    z = layers.MlpBlock(
        intermediate_dim=cfg.mlp_dim,
        activations=cfg.mlp_activations,
        intermediate_dropout_rate=cfg.dropout_rate,
        dtype=cfg.dtype,
        name='mlp',
    )(z, deterministic=deterministic)
    z = nn.Dropout(
        rate=cfg.dropout_rate, broadcast_dims=(-2,))(
            z, deterministic=deterministic)
    z = z + y
    z = with_sharding_constraint(z, ('batch', 'length', 'embed'))

    if cfg.scan_layers:
      return z, None
    else:
      return z


class Encoder(nn.Module):
  """A stack of encoder layers."""
  config: T5Config
  shared_embedding: nn.Module

  @nn.compact
  def __call__(self,
               encoder_input_tokens,
               encoder_mask=None,
               deterministic=False):
    cfg = self.config
    assert encoder_input_tokens.ndim == 2  # [batch, length]

    # [batch, length] -> [batch, length, emb_dim]
    x = self.shared_embedding(encoder_input_tokens.astype('int32'))
    x = nn.Dropout(
        rate=cfg.dropout_rate, broadcast_dims=(-2,))(
            x, deterministic=deterministic)
    x = x.astype(cfg.dtype)

    BlockLayer = EncoderLayer

    if cfg.remat_policy not in (None, 'none'):
      if cfg.remat_policy == 'minimal':
        policy = jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims
      else:
        policy = None
      BlockLayer = remat(  # pylint: disable=invalid-name
          BlockLayer,
          prevent_cse=not cfg.scan_layers,
          policy=policy,
          static_argnums=(2,))
    if cfg.scan_layers:
      initializing = self.is_mutable_collection('params')
      params_spec = (
          cfg.param_scan_axis if initializing else ScanIn(cfg.param_scan_axis))
      cache_spec = 0
      x, _ = scan_with_axes(
          BlockLayer,
          variable_axes={
              'params': params_spec,
              'cache': cache_spec,
          },
          split_rngs={
              'params': True,
              'dropout': True
          },
          in_axes=(nn.broadcast, nn.broadcast),
          length=cfg.num_encoder_layers,
          axis_name='layers')(
              config=cfg, name='layers')(x, encoder_mask, deterministic)
    else:
      for lyr in range(cfg.num_encoder_layers):
        # [batch, length, emb_dim] -> [batch, length, emb_dim]
        x = BlockLayer(
            config=cfg, name=f'layers_{lyr}')(x, encoder_mask, deterministic)

    x = layers.LayerNorm(dtype=cfg.dtype, name='encoder_norm')(x)
    return nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=deterministic)


class Decoder(nn.Module):
  """A stack of decoder layers as a part of an encoder-decoder architecture."""
  config: T5Config
  shared_embedding: nn.Module

  @nn.compact
  def __call__(self,
               encoded,
               decoder_input_tokens,
               decoder_positions=None,
               decoder_mask=None,
               encoder_decoder_mask=None,
               deterministic=False,
               decode=False,
               max_decode_length=None):
    cfg = self.config
    assert decoder_input_tokens.ndim == 2  # [batch, len]

    # [batch, length] -> [batch, length, emb_dim]
    y = self.shared_embedding(decoder_input_tokens.astype('int32'))
    y = nn.Dropout(
        rate=cfg.dropout_rate, broadcast_dims=(-2,))(
            y, deterministic=deterministic)
    y = y.astype(cfg.dtype)

    BlockLayer = DecoderLayer

    if cfg.remat_policy not in (None, 'none'):
      if cfg.remat_policy == 'minimal':
        policy = jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims
      else:
        policy = None
      BlockLayer = remat(  # pylint: disable=invalid-name
          BlockLayer,
          prevent_cse=not cfg.scan_layers,
          policy=policy,
          static_argnums=(4, 5, 6))
    if cfg.scan_layers:
      initializing = self.is_mutable_collection('params')
      params_spec = (
          cfg.param_scan_axis if initializing else ScanIn(cfg.param_scan_axis))
      cache_spec = 0
      y, _ = scan_with_axes(
          BlockLayer,
          variable_axes={
              'params': params_spec,
              'cache': cache_spec
          },
          split_rngs={
              'params': True,
              'dropout': True
          },
          in_axes=(nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast,
                   nn.broadcast, nn.broadcast),
          length=cfg.num_decoder_layers,
          axis_name='layers')(
              config=cfg, name='layers')(
                  y,
                  encoded,
                  decoder_mask,
                  encoder_decoder_mask,
                  deterministic,
                  decode,
                  max_decode_length)
    else:
      for lyr in range(cfg.num_decoder_layers):
        # [batch, length, emb_dim] -> [batch, length, emb_dim]
        y = BlockLayer(
            config=cfg, name=f'layers_{lyr}')(
                y,
                encoded,
                decoder_mask=decoder_mask,
                encoder_decoder_mask=encoder_decoder_mask,
                deterministic=deterministic,
                decode=decode,
                max_decode_length=max_decode_length)

    y = layers.LayerNorm(dtype=cfg.dtype, name='decoder_norm')(y)
    y = nn.Dropout(
        rate=cfg.dropout_rate, broadcast_dims=(-2,))(
            y, deterministic=deterministic)

    # [batch, length, emb_dim] -> [batch, length, vocab_size]
    if cfg.logits_via_embedding:
      # Use the transpose of embedding matrix for logit transform.
      logits = self.shared_embedding.attend(y)
      # Correctly normalize pre-softmax logits for this shared case.
      logits = logits / jnp.sqrt(y.shape[-1])
    else:
      logits = layers.DenseGeneral(
          cfg.vocab_size,
          dtype=jnp.float32,  # Use float32 for stability.
          kernel_axes=('embed', 'vocab'),
          name='logits_dense')(
              y)
    return logits


class Transformer(nn.Module):
  """An encoder-decoder Transformer model."""
  config: T5Config
  # needed only for janky models.py scan_layers detection.
  scan_layers: bool = struct.field(init=False)

  def __post_init__(self):
    super().__post_init__()
    # needed only for janky models.py scan_layers detection.
    object.__setattr__(self, 'scan_layers',
                       object.__getattribute__(self, 'config').scan_layers)

  def setup(self):
    cfg = self.config
    self.shared_embedding = layers.Embed(
        num_embeddings=cfg.vocab_size,
        features=cfg.emb_dim,
        dtype=cfg.dtype,
        attend_dtype=jnp.float32,  # for logit training stability
        embedding_init=nn.initializers.normal(stddev=1.0),
        one_hot=True,
        name='token_embedder')

    self.encoder = Encoder(config=cfg, shared_embedding=self.shared_embedding)
    self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding)

  def encode(self,
             encoder_input_tokens,
             encoder_segment_ids=None,
             enable_dropout=True):
    """Applies Transformer encoder-branch on the inputs."""
    cfg = self.config
    assert encoder_input_tokens.ndim == 2  # (batch, len)

    # Make padding attention mask.
    encoder_mask = layers.make_attention_mask(
        encoder_input_tokens > 0, encoder_input_tokens > 0, dtype=cfg.dtype)
    # Add segmentation block-diagonal attention mask if using segmented data.
    if encoder_segment_ids is not None:
      encoder_mask = layers.combine_masks(
          encoder_mask,
          layers.make_attention_mask(
              encoder_segment_ids,
              encoder_segment_ids,
              jnp.equal,
              dtype=cfg.dtype))

    return self.encoder(
        encoder_input_tokens, encoder_mask, deterministic=not enable_dropout)

  def decode(
      self,
      encoded,
      encoder_input_tokens,  # only needed for masks
      decoder_input_tokens,
      decoder_target_tokens,
      encoder_segment_ids=None,
      decoder_segment_ids=None,
      decoder_positions=None,
      enable_dropout=True,
      decode=False,
      max_decode_length=None):
    """Applies Transformer decoder-branch on encoded-input and target."""
    cfg = self.config

    # Make padding attention masks.
    if decode:
      # Do not mask decoder attention based on targets padding at
      # decoding/inference time.
      decoder_mask = None
      encoder_decoder_mask = layers.make_attention_mask(
          jnp.ones_like(decoder_target_tokens),
          encoder_input_tokens > 0,
          dtype=cfg.dtype)
    else:
      decoder_mask = layers.make_decoder_mask(
          decoder_target_tokens=decoder_target_tokens,
          dtype=cfg.dtype,
          decoder_segment_ids=decoder_segment_ids)
      encoder_decoder_mask = layers.make_attention_mask(
          decoder_target_tokens > 0, encoder_input_tokens > 0, dtype=cfg.dtype)

    # Add segmentation block-diagonal attention masks if using segmented data.
    if encoder_segment_ids is not None:
      if decode:
        raise ValueError(
            'During decoding, packing should not be used but '
            '`encoder_segment_ids` was passed to `Transformer.decode`.')

      encoder_decoder_mask = layers.combine_masks(
          encoder_decoder_mask,
          layers.make_attention_mask(
              decoder_segment_ids, encoder_segment_ids, jnp.equal,
              dtype=cfg.dtype))

    logits = self.decoder(
        encoded,
        decoder_input_tokens=decoder_input_tokens,
        decoder_positions=decoder_positions,
        decoder_mask=decoder_mask,
        encoder_decoder_mask=encoder_decoder_mask,
        deterministic=not enable_dropout,
        decode=decode,
        max_decode_length=max_decode_length)
    return logits

  def __call__(self,
               encoder_input_tokens,
               decoder_input_tokens,
               decoder_target_tokens,
               encoder_segment_ids=None,
               decoder_segment_ids=None,
               encoder_positions=None,
               decoder_positions=None,
               *,
               enable_dropout: bool = True,
               decode: bool = False):
    """Applies Transformer model on the inputs.

    This method requires both decoder_target_tokens and decoder_input_tokens,
    which is a shifted version of the former. For a packed dataset, it usually
    has additional processing applied. For example, the first element of each
    sequence has id 0 instead of the shifted EOS id from the previous sequence.

    Args:
      encoder_input_tokens: input data to the encoder.
      decoder_input_tokens: input token to the decoder.
      decoder_target_tokens: target token to the decoder.
      encoder_segment_ids: encoder segmentation info for packed examples.
      decoder_segment_ids: decoder segmentation info for packed examples.
      encoder_positions: encoder subsequence positions for packed examples.
      decoder_positions: decoder subsequence positions for packed examples.
      enable_dropout: Enables dropout if set to True.
      decode: Whether to prepare and use an autoregressive cache.

    Returns:
      logits array from full transformer.
    """
    encoded = self.encode(
        encoder_input_tokens,
        encoder_segment_ids=encoder_segment_ids,
        enable_dropout=enable_dropout)

    return self.decode(
        encoded,
        encoder_input_tokens,  # only used for masks
        decoder_input_tokens,
        decoder_target_tokens,
        encoder_segment_ids=encoder_segment_ids,
        decoder_segment_ids=decoder_segment_ids,
        decoder_positions=decoder_positions,
        enable_dropout=enable_dropout,
        decode=decode)
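

# Usage sketch (illustrative only, not part of the library): builds a tiny
# config and runs a single deterministic forward pass on dummy token ids.
# The hyperparameter values, batch/sequence shapes, and running without
# partitioning rules are assumptions for a single-device smoke test; in real
# T5X training this module is driven through t5x.models with logical-axis
# rules, under which the `with_sharding_constraint` calls above become actual
# sharding annotations (with no rules set they are no-ops).
if __name__ == '__main__':
  tiny_config = T5Config(
      vocab_size=32,
      emb_dim=16,
      num_heads=2,
      num_encoder_layers=2,
      num_decoder_layers=2,
      head_dim=8,
      mlp_dim=32,
      dropout_rate=0.0)
  model = Transformer(config=tiny_config)

  batch, enc_len, dec_len = 2, 8, 4
  enc_tokens = jnp.ones((batch, enc_len), dtype=jnp.int32)
  dec_in = jnp.ones((batch, dec_len), dtype=jnp.int32)
  dec_tgt = jnp.ones((batch, dec_len), dtype=jnp.int32)

  # Initialize parameters (scanned layers stack their weights along
  # `param_scan_axis`), then apply the model with dropout disabled.
  variables = model.init(
      {'params': jax.random.PRNGKey(0), 'dropout': jax.random.PRNGKey(1)},
      enc_tokens, dec_in, dec_tgt,
      enable_dropout=False)
  logits = model.apply(
      variables, enc_tokens, dec_in, dec_tgt, enable_dropout=False)
  print(logits.shape)  # Expected: (batch, dec_len, vocab_size) == (2, 4, 32).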