File size: 4,951 Bytes
a85f909
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from functools import partial
from time import time
import os
import numpy as np
import jax
import jax.flatten_util
import jax.numpy as jnp
import mlxu
from EasyLM.bpt import blockwise_attn
from EasyLM.jax_utils import (
    get_float_dtype_by_name, set_random_seed, next_rng, JaxRNG
)


# Command-line flags (with their defaults) that configure the attention
# benchmark run by main().
FLAGS, _ = mlxu.define_flags_with_default(
    seed=42,  # passed to set_random_seed in main
    dtype='fp32',  # dtype name, resolved with get_float_dtype_by_name
    embed_dim=2048,  # total width; per-head size is embed_dim // n_heads
    n_heads=16,  # number of attention heads in the random q/k/v tensors
    ref_attn_seq_len=2048,  # sequence length used for the reference attention
    eff_attn_seq_len=16384,  # sequence length used for blockwise attention
    batch_size=1,  # leading dimension of the random q/k/v tensors
    query_chunk_size=2048,  # forwarded to blockwise_attn
    key_chunk_size=2048,  # forwarded to blockwise_attn
    warmup_steps=40,  # untimed iterations run before each measurement
    steps=200,  # timed iterations per measurement
)


def main(argv):
    """Benchmark plain causal attention against ``blockwise_attn``.

    Times the jitted forward+backward pass of a straightforward softmax
    attention at ``FLAGS.ref_attn_seq_len`` and of ``blockwise_attn`` at
    ``FLAGS.eff_attn_seq_len``. Prints the total time of each measurement and
    an efficiency figure: the time ratio scaled by the squared ratio of the
    two sequence lengths.

    Args:
        argv: Positional command-line arguments supplied by the runner; unused.
    """
    # Function-scope import so this block stands on its own: interval timing
    # should use a monotonic clock, which wall-clock ``time()`` is not.
    from time import perf_counter

    float_dtype = get_float_dtype_by_name(FLAGS.dtype)

    def random_kqv(rng_key, seq_len):
        """Return three random normal tensors shaped (batch, seq, heads, head_dim)."""
        rng_generator = JaxRNG(rng_key)
        shape = (
            FLAGS.batch_size,
            seq_len,
            FLAGS.n_heads,
            FLAGS.embed_dim // FLAGS.n_heads,
        )
        # One fresh key per tensor, drawn in order.
        return tuple(
            jax.random.normal(rng_generator(), shape, dtype=float_dtype)
            for _ in range(3)
        )

    def reference_attn(query, key, value):
        """Causal softmax attention that materialises the full logits matrix."""
        # Scale queries by 1/sqrt(head_dim) before taking dot products.
        query = query / jnp.sqrt(query.shape[-1]).astype(float_dtype)
        logits = jnp.einsum("bqhc,bkhc->bhqk", query, key)
        mask_value = jnp.finfo(logits.dtype).min
        _, q_seq_len, _, _ = query.shape
        _, kv_seq_len, _, _ = key.shape
        mask_shape = (q_seq_len, kv_seq_len)
        row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
        col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
        # True where the key index is ahead of the query index; those logits
        # are pushed to the most negative finite value so softmax ignores them.
        causal_mask = (row_ids < col_ids)[None, None, :, :]
        logits = logits + jnp.where(causal_mask, mask_value, 0.0)
        weights = jax.nn.softmax(logits, axis=-1)
        return jnp.einsum("bhqk,bkhc->bqhc", weights, value)

    def efficient_attention(query, key, value):
        """Causal attention computed by ``blockwise_attn`` with the flag chunk sizes."""
        return blockwise_attn(
            query, key, value,
            bias=None,
            deterministic=True,
            dropout_rng=None,
            attn_pdrop=0.0,
            causal=True,
            query_chunk_size=FLAGS.query_chunk_size,
            key_chunk_size=FLAGS.key_chunk_size,
            dtype=float_dtype,
            policy=jax.checkpoint_policies.nothing_saveable(),
            precision=None,
            float32_logits=True,
            prevent_cse=True,
        )

    @partial(jax.jit, static_argnums=(1,))
    def reference_attn_forward_backward(rng_key, seq_len):
        """Forward+backward of the reference attention, reduced to one scalar."""
        @partial(jax.grad, argnums=(0, 1, 2))
        @partial(jax.checkpoint, policy=jax.checkpoint_policies.nothing_saveable())
        def grad_fn(query, key, value):
            out = reference_attn(query, key, value)
            return jnp.mean(out)

        query, key, value = random_kqv(rng_key, seq_len)
        # NOTE(review): only the key gradient (index 1) feeds the result; the
        # compiler may be able to prune the unused query/value gradients.
        # Kept as-is so measurements stay comparable; confirm this is intended.
        return jax.flatten_util.ravel_pytree(
            grad_fn(query, key, value)[1]
        )[0].mean()

    @partial(jax.jit, static_argnums=(1,))
    def efficient_attn_forward_backward(rng_key, seq_len):
        """Forward+backward of blockwise attention, reduced to one scalar."""
        @partial(jax.grad, argnums=(0, 1, 2))
        def grad_fn(query, key, value):
            out = efficient_attention(query, key, value)
            return jnp.mean(out)

        query, key, value = random_kqv(rng_key, seq_len)
        # Same reduction as the reference path (key gradient only).
        return jax.flatten_util.ravel_pytree(
            grad_fn(query, key, value)[1]
        )[0].mean()

    def run_steps(step_fn, seq_len, n_steps):
        """Dispatch ``n_steps`` calls, then block until every result is ready."""
        results = [step_fn(next_rng(), seq_len) for _ in range(n_steps)]
        jax.block_until_ready(results)

    def timed_run(step_fn, seq_len):
        """Warm up ``step_fn`` and return the seconds taken by the timed steps."""
        run_steps(step_fn, seq_len, FLAGS.warmup_steps)
        start_time = perf_counter()
        run_steps(step_fn, seq_len, FLAGS.steps)
        return perf_counter() - start_time

    set_random_seed(FLAGS.seed)

    # Trigger compilation of both jitted functions before any measurement.
    jax.block_until_ready(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len))
    jax.block_until_ready(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len))

    elapsed_time_ref_attn = timed_run(
        reference_attn_forward_backward, FLAGS.ref_attn_seq_len
    )
    print(f'Reference attention: {elapsed_time_ref_attn:.3f} seconds')

    elapsed_time_efficient_attn = timed_run(
        efficient_attn_forward_backward, FLAGS.eff_attn_seq_len
    )
    print(f'Efficient attention: {elapsed_time_efficient_attn:.3f} seconds')

    # The two runs use different sequence lengths; scale the time ratio by the
    # squared length ratio so the figure compares time per unit of work.
    flops_ratio = (FLAGS.eff_attn_seq_len / FLAGS.ref_attn_seq_len) ** 2
    efficiency = elapsed_time_ref_attn / elapsed_time_efficient_attn * flops_ratio
    print(f'Efficiency: {efficiency:.3f}')


# Script entry point: hand main to mlxu.run when executed directly
# (presumably mlxu.run parses the flags defined above before calling main — confirm).
if __name__ == '__main__':
    mlxu.run(main)