""" An implementation of Blockwise parallel transformer https://arxiv.org/abs/2305.19370 Also include a reference implementation of memory-efficient transformer https://arxiv.org/abs/2112.05682 """ import functools from typing import NamedTuple import flax.linen as nn import jax import jax.lax as lax import jax.numpy as jnp from einops import rearrange """ Computing ffn blockwise without materializing the large hidden tensor, training 4x longer sequences than the memory-efficient transformer. Blockwise parallel transformer https://arxiv.org/abs/2305.19370 Liu et al. 2023 """ def blockwise_ffn(remat_ffn, inputs, chunk_size=2048, deterministic=True): # remat_ffn: a rematerialized ffn with policy jax.checkpoint_policies.nothing_saveable() # inputs: (batch, seq_len, dim) # chunk_size: the chunk size to split the sequence inputs = rearrange(inputs, 'b (c n) d -> b c n d', c=chunk_size) def scan_ffn(remat_ffn, carry, hidden_states): outputs = remat_ffn(hidden_states, deterministic=deterministic) return carry, outputs scan_axis = inputs.ndim - 2 _, res = nn.scan( scan_ffn, variable_broadcast="params", split_rngs={"params": False, "dropout": True}, in_axes=scan_axis, out_axes=scan_axis, )(remat_ffn, None, inputs) res = rearrange(res, 'b c n d -> b (c n) d') return res """ Compute attention blockwise without materializing the full attention matrix, initially proposed in memory-efficient transformer https://arxiv.org/abs/2112.05682 Rabe et al. 2021; flash attention https://arxiv.org/abs/2205.14135 Dao et al. 2022 proposes a CUDA efficient implementation; blockwise parallel transformer https://arxiv.org/abs/2305.19370 Liu et al. 2023 proposes blockwise computing both attention and FFN, enabling 4x longer sequences than memory-efficient/flash-attention and fusion of attention and FFN. """ def blockwise_attn( query, key, value, bias=None, deterministic=True, dropout_rng=None, attn_pdrop=0.0, causal=True, query_chunk_size=2048, key_chunk_size=2048, dtype=jnp.float32, policy=jax.checkpoint_policies.nothing_saveable(), precision=None, float32_logits=True, prevent_cse=True, ): # query, key, value: (batch, seq_len, num_heads, dim_per_head) # bias: (batch, seq_len) can be used to mask out attention (e.g. 
"""
Compute attention blockwise without materializing the full attention matrix,
initially proposed in the memory-efficient transformer https://arxiv.org/abs/2112.05682
Rabe et al. 2021; flash attention https://arxiv.org/abs/2205.14135 Dao et al. 2022
proposes an efficient CUDA implementation; blockwise parallel transformer
https://arxiv.org/abs/2305.19370 Liu et al. 2023 proposes computing both attention
and FFN blockwise, enabling 4x longer sequences than memory-efficient/flash-attention
and fusion of attention and FFN.
"""
def blockwise_attn(
        query, key, value,
        bias=None,
        deterministic=True,
        dropout_rng=None,
        attn_pdrop=0.0,
        causal=True,
        query_chunk_size=2048,
        key_chunk_size=2048,
        dtype=jnp.float32,
        policy=jax.checkpoint_policies.nothing_saveable(),
        precision=None,
        float32_logits=True,
        prevent_cse=True,
):
    # query, key, value: (batch, seq_len, num_heads, dim_per_head)
    # bias: broadcastable to (batch, num_heads, q_len, kv_len), can be used to
    #       mask out attention (e.g. padding)
    # causal: whether to use causal mask
    # policy: one of jax.checkpoint_policies
    query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
    if float32_logits:
        query = query.astype(jnp.float32)
        key = key.astype(jnp.float32)

    batch, q_len, num_heads, dim_per_head = query.shape
    batch, kv_len, num_heads, dim_per_head = key.shape
    batch, kv_len, num_heads, dim_per_head = value.shape

    num_q = q_len // query_chunk_size
    num_kv = kv_len // key_chunk_size
    query = query.reshape((batch, num_q, query_chunk_size, num_heads, dim_per_head))
    key = key.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
    value = value.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))

    query = jnp.moveaxis(query, 1, 0)
    key = jnp.moveaxis(key, 1, 0)
    value = jnp.moveaxis(value, 1, 0)

    if bias is not None:
        for bias_dim, broadcast_dim in zip(bias.shape, (batch, num_heads, q_len, kv_len)):
            assert bias_dim == 1 or bias_dim == broadcast_dim
    if not deterministic and attn_pdrop > 0.0:
        attn_dropout_rng, dropout_rng = jax.random.split(dropout_rng)
        attn_dropout = jax.random.bernoulli(attn_dropout_rng, attn_pdrop, (batch, num_heads, q_len, kv_len))
    else:
        attn_dropout = None

    _chunk_bias_fn = functools.partial(
        _chunk_attention_bias,
        query_chunk_size, key_chunk_size, bias, deterministic,
        attn_dropout, attn_pdrop, causal, dtype)

    def scan_attention(args):
        query_chunk, query_chunk_idx = args

        @functools.partial(jax.checkpoint, prevent_cse=prevent_cse, policy=policy)
        def scan_kv_block(carry, args):
            key_chunk, value_chunk, key_chunk_idx = args
            (numerator, denominator, prev_max_score) = carry
            attn_weights = jnp.einsum('bqhd,bkhd->bqhk', query_chunk, key_chunk, precision=precision)
            bias_chunk = _chunk_bias_fn(query_chunk_idx, key_chunk_idx)
            bias_chunk = jnp.moveaxis(bias_chunk, 1, 2)
            attn_weights = attn_weights + bias_chunk

            max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
            max_score = jnp.maximum(prev_max_score, max_score)
            max_score = jax.lax.stop_gradient(max_score)
            exp_weights = jnp.exp(attn_weights - max_score)
            exp_values = jnp.einsum(
                'bqhv,bvhd->bqhd', exp_weights, value_chunk, precision=precision
            )
            correction = jnp.exp(prev_max_score - max_score)
            numerator = numerator * correction + exp_values
            denominator = denominator * correction + exp_weights.sum(axis=-1, keepdims=True)
            return Carry(numerator, denominator, max_score), None

        def skip_upper_half(carry, args):
            key_chunk, value_chunk, key_chunk_idx = args
            skip_block = jnp.array(False)
            if causal:
                skip_block = query_chunk_idx < key_chunk_idx
            return jax.lax.cond(
                skip_block,
                lambda carry, args: (carry, None),
                scan_kv_block,
                carry,
                args,
            )

        init_carry = Carry(
            jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype),
            jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype),
            (-jnp.inf) * jnp.ones((batch, query_chunk_size, num_heads, 1), dtype=query.dtype),
        )
        (numerator, denominator, max_score), _ = lax.scan(
            skip_upper_half, init_carry, xs=(key, value, jnp.arange(0, num_kv))
        )
        outputs = (numerator / denominator).astype(dtype)
        return outputs

    _, res = lax.scan(
        lambda _, x: ((), scan_attention(x)),
        (), xs=(query, jnp.arange(0, num_q))
    )
    res = rearrange(res, 'n b c h d -> b (n c) h d')
    return res


class Carry(NamedTuple):
    numerator: jax.Array
    denominator: jax.Array
    max_so_far: jax.Array


def _chunk_attention_bias(query_chunk_size, key_chunk_size,
            bias, deterministic, attn_dropout, attn_pdrop, causal,
            dtype, query_chunk_idx, key_chunk_idx):
    query_offset = query_chunk_idx * query_chunk_size
    key_offset = key_chunk_idx * key_chunk_size
    chunk_bias = jnp.zeros((1, 1, 1, 1), dtype=dtype)
    if bias is not None:
        chunk_bias = lax.dynamic_slice(
            bias,
            start_indices=(0, 0, query_offset, key_offset),
            slice_sizes=(*bias.shape[:2], min(bias.shape[-2], query_chunk_size), min(bias.shape[-1], key_chunk_size)),
        )

    if causal:
        query_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(query_chunk_size, 1), dimension=0)
        key_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(1, key_chunk_size), dimension=1)
        offset = query_offset - key_offset
        query_idx += offset
        causal_mask_value = (query_idx < key_idx) * jnp.finfo(dtype).min
        chunk_bias += causal_mask_value.reshape(1, 1, *causal_mask_value.shape)

    if not deterministic and attn_pdrop > 0.0:
        attn_dropout_slice = lax.dynamic_slice(
            attn_dropout,
            start_indices=(0, 0, query_offset, key_offset),
            slice_sizes=(
                *attn_dropout.shape[:2],
                min(attn_dropout.shape[-2], query_chunk_size),
                min(attn_dropout.shape[-1], key_chunk_size),
            ),
        )
        chunk_bias += attn_dropout_slice * jnp.finfo(dtype).min
    return chunk_bias.astype(dtype)
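
# Example usage of blockwise_attn inside a Flax module (a minimal sketch, not part
# of the reference implementation): `ExampleAttention` and its sizes are illustrative
# assumptions. blockwise_attn itself is a pure function of arrays, so it can also be
# called directly on (batch, seq_len, num_heads, dim_per_head) tensors, as in the
# test at the bottom of this file.
class ExampleAttention(nn.Module):
    dim: int = 64
    num_heads: int = 8
    chunk_size: int = 8

    @nn.compact
    def __call__(self, x, deterministic=True):
        # x: (batch, seq_len, dim) with seq_len divisible by chunk_size
        batch, seq_len, _ = x.shape
        head_dim = self.dim // self.num_heads
        split_heads = lambda t: t.reshape(batch, seq_len, self.num_heads, head_dim)
        q = split_heads(nn.Dense(self.dim, use_bias=False)(x))
        k = split_heads(nn.Dense(self.dim, use_bias=False)(x))
        v = split_heads(nn.Dense(self.dim, use_bias=False)(x))
        attn = blockwise_attn(
            q, k, v,
            deterministic=deterministic,
            causal=True,
            query_chunk_size=self.chunk_size,
            key_chunk_size=self.chunk_size,
        )
        attn = attn.reshape(batch, seq_len, self.dim)
        return nn.Dense(self.dim, use_bias=False)(attn)
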
if __name__ == '__main__':
    # test
    def reference_attn(query, key, value, causal, dtype):
        query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
        logits = jnp.einsum("bqhc,bkhc->bhqk", query, key)
        if causal:
            mask_value = jnp.finfo(logits.dtype).min
            _, q_seq_len, _, _ = query.shape
            _, kv_seq_len, _, _ = key.shape
            mask_shape = (q_seq_len, kv_seq_len)
            row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
            col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
            causal_mask = (row_ids < col_ids)[None, None, :, :]
            logits = logits + jnp.where(causal_mask, mask_value, 0.0)
        weights = jax.nn.softmax(logits, axis=-1)
        out = jnp.einsum("bhqk,bkhc->bqhc", weights, value)
        return out

    # random inputs
    shape = (1, 32, 8, 64)
    query = jax.random.normal(jax.random.PRNGKey(0), shape)
    key = jax.random.normal(jax.random.PRNGKey(1), shape)
    value = jax.random.normal(jax.random.PRNGKey(2), shape)

    causal = True
    chunk_size = 4
    policy = jax.checkpoint_policies.nothing_saveable()

    blockwise = blockwise_attn(
        query, key, value, None, False, None, 0.0, causal,
        chunk_size, chunk_size, jnp.float32, policy, 'float32', True, False,
    )
    reference = reference_attn(query, key, value, causal, 'float32')

    assert jnp.allclose(reference, blockwise, atol=1e-6)
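
    # Additional check (not part of the original test): the non-causal path should
    # also agree with the reference implementation, using the same tolerance.
    blockwise_nc = blockwise_attn(
        query, key, value, None, False, None, 0.0, False,
        chunk_size, chunk_size, jnp.float32, policy, 'float32', True, False,
    )
    reference_nc = reference_attn(query, key, value, False, 'float32')
    assert jnp.allclose(reference_nc, blockwise_nc, atol=1e-6)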