# Ahma-7B / EasyLM / scripts / benchmark_attention.py
# aapot — commit a85f909: "Add training codes"
from functools import partial
from time import time
import os
import numpy as np
import jax
import jax.flatten_util
import jax.numpy as jnp
import mlxu
from EasyLM.bpt import blockwise_attn
from EasyLM.jax_utils import (
get_float_dtype_by_name, set_random_seed, next_rng, JaxRNG
)
# Command-line flags (with defaults) that configure the attention benchmark
# run by main() below.
FLAGS, _ = mlxu.define_flags_with_default(
    # Seed handed to set_random_seed() before any q/k/v are sampled.
    seed=42,
    # Float dtype name, resolved with get_float_dtype_by_name() for the
    # sampled q/k/v and for both attention implementations.
    dtype='fp32',
    # Model width; the per-head size is embed_dim // n_heads.
    embed_dim=2048,
    n_heads=16,
    # Sequence length used for the reference (full-logits) attention.
    ref_attn_seq_len=2048,
    # Sequence length used for blockwise_attn.
    eff_attn_seq_len=16384,
    batch_size=1,
    # Chunk sizes forwarded to blockwise_attn.
    query_chunk_size=2048,
    key_chunk_size=2048,
    # Untimed iterations run before each timed run.
    warmup_steps=40,
    # Timed iterations per implementation.
    steps=200,
)
def main(argv):
    """Benchmark causal attention forward+backward: reference vs. blockwise.

    Times ``FLAGS.steps`` jitted forward+backward passes of a plain causal
    attention (which materializes the full logits matrix) at
    ``FLAGS.ref_attn_seq_len`` and of EasyLM's ``blockwise_attn`` at
    ``FLAGS.eff_attn_seq_len``. Each timed run is preceded by
    ``FLAGS.warmup_steps`` untimed passes. Prints both totals and an
    efficiency score normalized by the squared sequence-length ratio.

    Args:
        argv: Unused; present only to satisfy the ``main(argv)`` calling
            convention of the runner this function is passed to.

    Raises:
        ValueError: if ``FLAGS.embed_dim`` is not divisible by
            ``FLAGS.n_heads``.
    """
    # perf_counter is monotonic and high-resolution, unlike wall-clock time().
    from time import perf_counter

    if FLAGS.embed_dim % FLAGS.n_heads != 0:
        # Floor division would otherwise silently benchmark a smaller model.
        raise ValueError(
            f'embed_dim ({FLAGS.embed_dim}) must be divisible by '
            f'n_heads ({FLAGS.n_heads})'
        )
    head_dim = FLAGS.embed_dim // FLAGS.n_heads
    dtype = get_float_dtype_by_name(FLAGS.dtype)
    # Pass the policy *function*; calling nothing_saveable() would pass False.
    remat_policy = jax.checkpoint_policies.nothing_saveable

    def random_kqv(rng_key, seq_len):
        """Sample standard-normal (query, key, value) arrays.

        Each array has shape (batch_size, seq_len, n_heads, head_dim).
        """
        rng_generator = JaxRNG(rng_key)
        shape = (FLAGS.batch_size, seq_len, FLAGS.n_heads, head_dim)
        return tuple(
            jax.random.normal(rng_generator(), shape, dtype=dtype)
            for _ in range(3)
        )

    def reference_attn(query, key, value):
        """Causal softmax attention that materializes all (q, k) logits."""
        # Scale scores by 1/sqrt(head_dim).
        query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
        logits = jnp.einsum("bqhc,bkhc->bhqk", query, key)
        mask_value = jnp.finfo(logits.dtype).min
        q_seq_len = query.shape[1]
        kv_seq_len = key.shape[1]
        mask_shape = (q_seq_len, kv_seq_len)
        row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
        col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
        # Query position i must not attend to key positions j > i.
        causal_mask = (row_ids < col_ids)[None, None, :, :]
        logits = logits + jnp.where(causal_mask, mask_value, 0.0)
        weights = jax.nn.softmax(logits, axis=-1)
        return jnp.einsum("bhqk,bkhc->bqhc", weights, value)

    def efficient_attention(query, key, value):
        """Causal blockwise attention, chunked per the *_chunk_size flags."""
        return blockwise_attn(
            query, key, value,
            bias=None,
            deterministic=True,
            dropout_rng=None,
            attn_pdrop=0.0,
            causal=True,
            query_chunk_size=FLAGS.query_chunk_size,
            key_chunk_size=FLAGS.key_chunk_size,
            dtype=dtype,
            policy=remat_policy,
            precision=None,
            float32_logits=True,
            prevent_cse=True,
        )

    def make_forward_backward(attn_fn, rematerialize):
        """Build a jitted step: sample q/k/v and differentiate mean(attn).

        The step returns a scalar that depends on *all three* gradients, so
        the compiled program cannot drop the dq/dv computations.
        """
        def loss_fn(query, key, value):
            return jnp.mean(attn_fn(query, key, value))

        if rematerialize:
            loss_fn = jax.checkpoint(loss_fn, policy=remat_policy)
        grad_fn = jax.grad(loss_fn, argnums=(0, 1, 2))

        # seq_len fixes array shapes, so it must be static under jit.
        @partial(jax.jit, static_argnums=(1,))
        def forward_backward(rng_key, seq_len):
            query, key, value = random_kqv(rng_key, seq_len)
            grads = grad_fn(query, key, value)
            # Per-leaf means avoid materializing a concatenation of the grads.
            return sum(jnp.mean(g) for g in grads)

        return forward_backward

    def run_steps(step_fn, seq_len, num_steps):
        """Dispatch num_steps calls on fresh keys, wait for all; return seconds."""
        start = perf_counter()
        # Dispatch everything before blocking so host dispatch overlaps device work.
        results = [step_fn(next_rng(), seq_len) for _ in range(num_steps)]
        jax.block_until_ready(results)
        return perf_counter() - start

    def benchmark(step_fn, seq_len):
        """Run the untimed warmup steps, then return the timed-steps duration."""
        run_steps(step_fn, seq_len, FLAGS.warmup_steps)
        return run_steps(step_fn, seq_len, FLAGS.steps)

    reference_attn_forward_backward = make_forward_backward(
        reference_attn, rematerialize=True
    )
    efficient_attn_forward_backward = make_forward_backward(
        efficient_attention, rematerialize=False
    )

    set_random_seed(FLAGS.seed)
    # Trigger compilation of both programs before any timing.
    jax.block_until_ready(
        reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len)
    )
    jax.block_until_ready(
        efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len)
    )

    elapsed_time_ref_attn = benchmark(
        reference_attn_forward_backward, FLAGS.ref_attn_seq_len
    )
    print(f'Reference attention: {elapsed_time_ref_attn:.3f} seconds')

    elapsed_time_efficient_attn = benchmark(
        efficient_attn_forward_backward, FLAGS.eff_attn_seq_len
    )
    print(f'Efficient attention: {elapsed_time_efficient_attn:.3f} seconds')

    # Attention cost grows with the square of sequence length, so the raw
    # time ratio is scaled by (eff_len / ref_len)^2 for a per-FLOP comparison.
    flops_ratio = (FLAGS.eff_attn_seq_len / FLAGS.ref_attn_seq_len) ** 2
    efficiency = elapsed_time_ref_attn / elapsed_time_efficient_attn * flops_ratio
    print(f'Efficiency: {efficiency:.3f}')
if __name__ == '__main__':
    # Script entry point: hand main to mlxu's runner, which presumably parses
    # the flags declared above before invoking main(argv) — confirm in mlxu.
    mlxu.run(main)