from functools import partial
from time import time
import os
import numpy as np
import jax
import jax.flatten_util
import jax.numpy as jnp
import mlxu

from EasyLM.bpt import blockwise_attn
from EasyLM.jax_utils import (
    get_float_dtype_by_name, set_random_seed, next_rng, JaxRNG
)


FLAGS, _ = mlxu.define_flags_with_default(
    seed=42,
    dtype='fp32',
    embed_dim=2048,
    n_heads=16,
    ref_attn_seq_len=2048,
    eff_attn_seq_len=16384,
    batch_size=1,
    query_chunk_size=2048,
    key_chunk_size=2048,
    warmup_steps=40,
    steps=200,
)


def main(argv):

    def random_kqv(rng_key, seq_len):
        # Sample a random (query, key, value) triple of shape
        # (batch, seq_len, n_heads, head_dim).
        rng_generator = JaxRNG(rng_key)
        kqv = []
        for i in range(3):
            kqv.append(
                jax.random.normal(
                    rng_generator(),
                    (FLAGS.batch_size, seq_len, FLAGS.n_heads, FLAGS.embed_dim // FLAGS.n_heads),
                    dtype=get_float_dtype_by_name(FLAGS.dtype)
                )
            )
        return tuple(kqv)

    def reference_attn(query, key, value):
        # Standard causal attention that materializes the full attention
        # matrix; serves as the baseline.
        dtype = get_float_dtype_by_name(FLAGS.dtype)
        query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
        logits = jnp.einsum("bqhc,bkhc->bhqk", query, key)
        mask_value = jnp.finfo(logits.dtype).min
        _, q_seq_len, _, _ = query.shape
        _, kv_seq_len, _, _ = key.shape
        mask_shape = (q_seq_len, kv_seq_len)
        row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
        col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
        # Mask out positions where the key index is ahead of the query index.
        causal_mask = (row_ids < col_ids)[None, None, :, :]
        logits = logits + jnp.where(causal_mask, mask_value, 0.0)
        weights = jax.nn.softmax(logits, axis=-1)
        out = jnp.einsum("bhqk,bkhc->bqhc", weights, value)
        return out

    def efficient_attention(query, key, value):
        # Memory-efficient blockwise attention from EasyLM.bpt.
        dtype = get_float_dtype_by_name(FLAGS.dtype)
        return blockwise_attn(
            query, key, value,
            bias=None,
            deterministic=True,
            dropout_rng=None,
            attn_pdrop=0.0,
            causal=True,
            query_chunk_size=FLAGS.query_chunk_size,
            key_chunk_size=FLAGS.key_chunk_size,
            dtype=dtype,
            policy=jax.checkpoint_policies.nothing_saveable(),
            precision=None,
            float32_logits=True,
            prevent_cse=True,
        )

    @partial(jax.jit, static_argnums=(1,))
    def reference_attn_forward_backward(rng_key, seq_len):
        @partial(jax.grad, argnums=(0, 1, 2))
        @partial(jax.checkpoint, policy=jax.checkpoint_policies.nothing_saveable())
        def grad_fn(query, key, value):
            out = reference_attn(query, key, value)
            return jnp.mean(out)

        query, key, value = random_kqv(rng_key, seq_len)
        # Reduce the key gradient to a scalar so only a tiny result is kept.
        return jax.flatten_util.ravel_pytree(
            grad_fn(query, key, value)[1]
        )[0].mean()

    @partial(jax.jit, static_argnums=(1,))
    def efficient_attn_forward_backward(rng_key, seq_len):
        @partial(jax.grad, argnums=(0, 1, 2))
        def grad_fn(query, key, value):
            out = efficient_attention(query, key, value)
            return jnp.mean(out)

        query, key, value = random_kqv(rng_key, seq_len)
        return jax.flatten_util.ravel_pytree(
            grad_fn(query, key, value)[1]
        )[0].mean()

    set_random_seed(FLAGS.seed)

    # Trigger compilation of both jitted functions before timing anything.
    jax.block_until_ready(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len))
    jax.block_until_ready(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len))

    # Warm up, then time the reference attention forward-backward pass.
    all_results = []
    for i in range(FLAGS.warmup_steps):
        all_results.append(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len))
    jax.block_until_ready(all_results)

    start_time = time()
    all_results = []
    for i in range(FLAGS.steps):
        all_results.append(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len))
    jax.block_until_ready(all_results)
    elapsed_time_ref_attn = time() - start_time
    print(f'Reference attention: {elapsed_time_ref_attn:.3f} seconds')

    # Warm up, then time the blockwise attention forward-backward pass.
    all_results = []
    for i in range(FLAGS.warmup_steps):
        all_results.append(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len))
    jax.block_until_ready(all_results)

    start_time = time()
    all_results = []
    for i in range(FLAGS.steps):
        all_results.append(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len))
    jax.block_until_ready(all_results)
    elapsed_time_efficient_attn = time() - start_time
    print(f'Efficient attention: {elapsed_time_efficient_attn:.3f} seconds')

    # Attention FLOPs grow quadratically with sequence length, so the timing
    # ratio is scaled by (eff_seq_len / ref_seq_len)^2 to compare per-FLOP speed.
    flops_ratio = (FLAGS.eff_attn_seq_len / FLAGS.ref_attn_seq_len) ** 2
    efficiency = elapsed_time_ref_attn / elapsed_time_efficient_attn * flops_ratio
    print(f'Efficiency: {efficiency:.3f}')


if __name__ == '__main__':
    mlxu.run(main)
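
# Example invocation (a sketch; the script filename below is an assumption, and
# the flag values are arbitrary). All flags defined above are exposed on the
# command line via mlxu/absl, e.g.:
#
#   python benchmark_attention.py \
#       --eff_attn_seq_len=32768 \
#       --query_chunk_size=512 \
#       --key_chunk_size=512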