Text Generation
Transformers
PyTorch
Safetensors
Finnish
llama
finnish
text-generation-inference
aapot committed on
Commit 0394e28
1 Parent(s): f7552ca

Update EasyLM

EasyLM/bpt.py ADDED
@@ -0,0 +1,228 @@
+ """
+ An implementation of Blockwise parallel transformer https://arxiv.org/abs/2305.19370
+ Also include a reference implementation of memory-efficient transformer https://arxiv.org/abs/2112.05682
+ """
+
+ import functools
+ from typing import NamedTuple
+
+ import flax.linen as nn
+ import jax
+ import jax.lax as lax
+ import jax.numpy as jnp
+ from einops import rearrange
+
+ """
+ Computing ffn blockwise without materializing the large hidden tensor, training
+ 4x longer sequences than the memory-efficient transformer.
+ Blockwise parallel transformer https://arxiv.org/abs/2305.19370 Liu et al. 2023
+ """
+ def blockwise_ffn(remat_ffn, inputs, chunk_size=2048, deterministic=True):
+     # remat_ffn: a rematerialized ffn with policy jax.checkpoint_policies.nothing_saveable()
+     # inputs: (batch, seq_len, dim)
+     # chunk_size: the chunk size to split the sequence
+     inputs = rearrange(inputs, 'b (c n) d -> b c n d', c=chunk_size)
+     def scan_ffn(remat_ffn, carry, hidden_states):
+         outputs = remat_ffn(hidden_states, deterministic=deterministic)
+         return carry, outputs
+     scan_axis = inputs.ndim - 2
+     _, res = nn.scan(
+         scan_ffn,
+         variable_broadcast="params",
+         split_rngs={"params": False, "dropout": True},
+         in_axes=scan_axis,
+         out_axes=scan_axis,
+     )(remat_ffn, None, inputs)
+     res = rearrange(res, 'b c n d -> b (c n) d')
+     return res
+
+
+ """
+ Compute attention blockwise without materializing the full attention matrix,
+ initially proposed in memory-efficient transformer https://arxiv.org/abs/2112.05682 Rabe et al. 2021;
+ flash attention https://arxiv.org/abs/2205.14135 Dao et al. 2022 proposes a CUDA
+ efficient implementation; blockwise parallel transformer https://arxiv.org/abs/2305.19370
+ Liu et al. 2023 proposes blockwise computing both attention and FFN, enabling 4x
+ longer sequences than memory-efficient/flash-attention and fusion of attention and FFN.
+ """
+ def blockwise_attn(
+         query, key, value,
+         bias=None,
+         deterministic=True,
+         dropout_rng=None,
+         attn_pdrop=0.0,
+         causal=True,
+         query_chunk_size=2048,
+         key_chunk_size=2048,
+         dtype=jnp.float32,
+         policy=jax.checkpoint_policies.nothing_saveable(),
+         precision=None,
+         float32_logits=True,
+         prevent_cse=True,
+ ):
+     # query, key, value: (batch, seq_len, num_heads, dim_per_head)
+     # bias: (batch, seq_len) can be used to mask out attention (e.g. padding)
+     # causal: whether to use causal mask
+     # policy: one of jax.checkpoint_policies
+     query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
+     if float32_logits:
+         query = query.astype(jnp.float32)
+         key = key.astype(jnp.float32)
+
+     batch, q_len, num_heads, dim_per_head = query.shape
+     batch, kv_len, num_heads, dim_per_head = key.shape
+     batch, kv_len, num_heads, dim_per_head = value.shape
+
+     num_q = q_len // query_chunk_size
+     num_kv = kv_len // key_chunk_size
+     query = query.reshape((batch, num_q, query_chunk_size, num_heads, dim_per_head))
+     key = key.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
+     value = value.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
+
+     query = jnp.moveaxis(query, 1, 0)
+     key = jnp.moveaxis(key, 1, 0)
+     value = jnp.moveaxis(value, 1, 0)
+
+     if bias is not None:
+         for bias_dim, broadcast_dim in zip(bias.shape, (batch, num_heads, q_len, kv_len)):
+             assert bias_dim == 1 or bias_dim == broadcast_dim
+     if not deterministic and attn_pdrop > 0.0:
+         attn_dropout_rng, dropout_rng = jax.random.split(dropout_rng)
+         attn_dropout = jax.random.bernoulli(attn_dropout_rng, attn_pdrop, (batch, num_heads, q_len, kv_len))
+     else:
+         attn_dropout = None
+
+     _chunk_bias_fn = functools.partial(
+         _chunk_attention_bias,
+         query_chunk_size, key_chunk_size, bias, deterministic,
+         attn_dropout, attn_pdrop, causal, dtype)
+
+     def scan_attention(args):
+         query_chunk, query_chunk_idx = args
+
+         @functools.partial(jax.checkpoint, prevent_cse=prevent_cse, policy=policy)
+         def scan_kv_block(carry, args):
+             key_chunk, value_chunk, key_chunk_idx = args
+             (numerator, denominator, prev_max_score) = carry
+             attn_weights = jnp.einsum('bqhd,bkhd->bqhk', query_chunk, key_chunk, precision=precision)
+             bias_chunk = _chunk_bias_fn(query_chunk_idx, key_chunk_idx)
+             bias_chunk = jnp.moveaxis(bias_chunk, 1, 2)
+             attn_weights = attn_weights + bias_chunk
+
+             max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
+             max_score = jnp.maximum(prev_max_score, max_score)
+             max_score = jax.lax.stop_gradient(max_score)
+             exp_weights = jnp.exp(attn_weights - max_score)
+             exp_values = jnp.einsum(
+                 'bqhv,bvhd->bqhd', exp_weights, value_chunk, precision=precision
+             )
+             correction = jnp.exp(prev_max_score - max_score)
+             numerator = numerator * correction + exp_values
+             denominator = denominator * correction + exp_weights.sum(axis=-1, keepdims=True)
+             return Carry(numerator, denominator, max_score), None
+
+         def skip_upper_half(carry, args):
+             key_chunk, value_chunk, key_chunk_idx = args
+             skip_block = jnp.array(False)
+             if causal:
+                 skip_block = query_chunk_idx < key_chunk_idx
+             return jax.lax.cond(
+                 skip_block,
+                 lambda carry, args: (carry, None),
+                 scan_kv_block,
+                 carry,
+                 args,
+             )
+
+         init_carry = Carry(
+             jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype),
+             jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype),
+             (-jnp.inf) * jnp.ones((batch, query_chunk_size, num_heads, 1), dtype=query.dtype),
+         )
+         (numerator, denominator, max_score), _ = lax.scan(
+             skip_upper_half, init_carry, xs=(key, value, jnp.arange(0, num_kv))
+         )
+         outputs = (numerator / denominator).astype(dtype)
+         return outputs
+
+     _, res = lax.scan(
+         lambda _, x: ((), scan_attention(x)),
+         (), xs=(query, jnp.arange(0, num_q))
+     )
+     res = rearrange(res, 'n b c h d -> b (n c) h d')
+     return res
+
+
+ class Carry(NamedTuple):
+     numerator: jax.Array
+     denominator: jax.Array
+     max_so_far: jax.Array
+
+
+ def _chunk_attention_bias(query_chunk_size, key_chunk_size,
+             bias, deterministic, attn_dropout, attn_pdrop, causal,
+             dtype, query_chunk_idx, key_chunk_idx):
+     query_offset = query_chunk_idx * query_chunk_size
+     key_offset = key_chunk_idx * key_chunk_size
+     chunk_bias = jnp.zeros((1, 1, 1, 1), dtype=dtype)
+     if bias is not None:
+         chunk_bias = lax.dynamic_slice(
+             bias,
+             start_indices=(0, 0, query_offset, key_offset),
+             slice_sizes=(*bias.shape[:2], min(bias.shape[-2], query_chunk_size), min(bias.shape[-1], key_chunk_size)),
+         )
+
+     if causal:
+         query_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(query_chunk_size, 1), dimension=0)
+         key_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(1, key_chunk_size), dimension=1)
+         offset = query_offset - key_offset
+         query_idx += offset
+         causal_mask_value = (query_idx < key_idx) * jnp.finfo(dtype).min
+         chunk_bias += causal_mask_value.reshape(1, 1, *causal_mask_value.shape)
+
+     if not deterministic and attn_pdrop > 0.0:
+         attn_dropout_slice = lax.dynamic_slice(
+             attn_dropout,
+             start_indices=(0, 0, query_offset, key_offset),
+             slice_sizes=(
+                 *attn_dropout.shape[:2],
+                 min(attn_dropout.shape[-2], query_chunk_size),
+                 min(attn_dropout.shape[-1], key_chunk_size),
+             ),
+         )
+         chunk_bias += attn_dropout_slice * jnp.finfo(dtype).min
+     return chunk_bias.astype(dtype)
+
+
+ if __name__ == '__main__':
+     # test
+     def reference_attn(query, key, value, causal, dtype):
+         query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
+         logits = jnp.einsum("bqhc,bkhc->bhqk", query, key)
+         if causal:
+             mask_value = jnp.finfo(logits.dtype).min
+             _, q_seq_len, _, _ = query.shape
+             _, kv_seq_len, _, _ = key.shape
+             mask_shape = (q_seq_len, kv_seq_len)
+             row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
+             col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
+             causal_mask = (row_ids < col_ids)[None, None, :, :]
+             logits = logits + jnp.where(causal_mask, mask_value, 0.0)
+         weights = jax.nn.softmax(logits, axis=-1)
+         out = jnp.einsum("bhqk,bkhc->bqhc", weights, value)
+         return out
+
+     # random inputs
+     shape = (1, 32, 8, 64)
+     query = jax.random.normal(jax.random.PRNGKey(0), shape)
+     key = jax.random.normal(jax.random.PRNGKey(1), shape)
+     value = jax.random.normal(jax.random.PRNGKey(2), shape)
+
+     causal = True
+     chunk_size = 4
+     policy = jax.checkpoint_policies.nothing_saveable()
+
+     blockwise = blockwise_attn(query, key, value, None, False, None, 0.0, causal, chunk_size, chunk_size, jnp.float32, policy, 'float32', True, False)
+     reference = reference_attn(query, key, value, causal, 'float32')
+
+     assert jnp.allclose(reference, blockwise, atol=1e-6)
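blockwise_ffn above expects an already rematerialized Flax module. Below is a minimal sketch, not part of this commit, of how it can be driven from a module's setup(), mirroring the way llama_model.py wires it in further down; SimpleMLP and BlockwiseMLP are hypothetical names and the sizes are illustrative.

import flax.linen as nn
import jax
import jax.numpy as jnp

from EasyLM.bpt import blockwise_ffn


class SimpleMLP(nn.Module):
    hidden_dim: int
    out_dim: int

    @nn.compact
    def __call__(self, x, deterministic=True):
        x = nn.gelu(nn.Dense(self.hidden_dim)(x))
        return nn.Dense(self.out_dim)(x)


class BlockwiseMLP(nn.Module):
    hidden_dim: int = 5504
    model_dim: int = 2048
    chunk_size: int = 512

    def setup(self):
        # Rematerialize the FFN so its large hidden activation is not saved,
        # as expected by the remat_ffn argument of blockwise_ffn.
        self.ffn = nn.remat(
            SimpleMLP,
            policy=jax.checkpoint_policies.nothing_saveable(),
        )(self.hidden_dim, self.model_dim)

    def __call__(self, x, deterministic=True):
        # The sequence length (axis 1 of x) must be divisible by chunk_size.
        return blockwise_ffn(self.ffn, x, self.chunk_size, deterministic)


x = jnp.zeros((1, 2048, 2048))
model = BlockwiseMLP()
params = model.init(jax.random.PRNGKey(0), x)
y = model.apply(params, x)  # same shape as x, computed chunk by chunk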
EasyLM/data.py CHANGED
@@ -3,6 +3,7 @@ import pprint
3
  import time
4
  from functools import partial
5
  import json
 
6
  from multiprocessing import Pool
7
 
8
  import h5py
@@ -59,6 +60,7 @@ class TextProcessor(object):
59
  config.add_bos_token = True
60
  config.add_eos_token = True
61
  config.prepend_text = ''
 
62
  if updates is not None:
63
  config.update(ConfigDict(updates).copy_and_resolve_references())
64
  return config
@@ -95,12 +97,26 @@ class TextProcessor(object):
95
  else:
96
  mask = 1.0
97
 
98
- if field == '<|bos|>':
99
- token_buffer.append(self.tokenizer.bos_token_id)
100
- loss_mask_buffer.append(mask)
101
- elif field == '<|eos|>':
102
- token_buffer.append(self.tokenizer.eos_token_id)
 
 
 
 
 
103
  loss_mask_buffer.append(mask)
 
 
 
 
 
 
 
 
 
104
  else:
105
  subfields = field.split('+')
106
  text = self.config.subfield_separator.join(
@@ -136,6 +152,7 @@ class HuggingfaceDataset(object):
136
  config.always_start_with_bos = False
137
  config.start_seek_loc = 0
138
  config.tokens_count_at_start = 0
 
139
 
140
  if updates is not None:
141
  config.update(ConfigDict(updates).copy_and_resolve_references())
@@ -163,6 +180,8 @@ class HuggingfaceDataset(object):
163
  while True:
164
  token_buffer = []
165
  loss_mask_buffer = []
 
 
166
  for index, example in enumerate(self._dataset):
167
  self._index = index
168
  if not self._eval_dataset and self._dataset_loc > index:
@@ -178,10 +197,10 @@ class HuggingfaceDataset(object):
178
  'epoch': self._train_epochs,
179
  }
180
  batch = {
181
- 'input_tokens': np.array(token_buffer[:chunk_size], dtype=np.int32).reshape(
182
  self.config.batch_size, -1
183
  ),
184
- 'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=np.int32).reshape(
185
  self.config.batch_size, -1
186
  ),
187
  'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(
 
3
  import time
4
  from functools import partial
5
  import json
6
+ import base64
7
  from multiprocessing import Pool
8
 
9
  import h5py
 
60
  config.add_bos_token = True
61
  config.add_eos_token = True
62
  config.prepend_text = ''
63
+ config.base64_token_dtype = 'i4'
64
  if updates is not None:
65
  config.update(ConfigDict(updates).copy_and_resolve_references())
66
  return config
 
97
  else:
98
  mask = 1.0
99
 
100
+ if field.startswith('<|') and field.endswith('|>'):
101
+ # Special tokens.
102
+ field = field[2:-2]
103
+ if field == 'bos':
104
+ token_buffer.append(self.tokenizer.bos_token_id)
105
+ elif field == 'eos':
106
+ token_buffer.append(self.tokenizer.eos_token_id)
107
+ else:
108
+ # Token ID specified directly.
109
+ token_buffer.append(int(field))
110
  loss_mask_buffer.append(mask)
111
+ elif field.startswith('{') and field.endswith('}'):
112
+ field = field[1:-1]
113
+ # Base64 encoded raw tokens.
114
+ tokens = np.frombuffer(
115
+ base64.b64decode(example[field]),
116
+ dtype=self.config.base64_token_dtype
117
+ ).tolist()
118
+ token_buffer.extend(tokens)
119
+ loss_mask_buffer.extend([mask for _ in range(len(tokens))])
120
  else:
121
  subfields = field.split('+')
122
  text = self.config.subfield_separator.join(
 
152
  config.always_start_with_bos = False
153
  config.start_seek_loc = 0
154
  config.tokens_count_at_start = 0
155
+ config.batch_token_dtype = 'i4'
156
 
157
  if updates is not None:
158
  config.update(ConfigDict(updates).copy_and_resolve_references())
 
180
  while True:
181
  token_buffer = []
182
  loss_mask_buffer = []
183
+ if not self._eval_dataset:
184
+ self._shuffle()
185
  for index, example in enumerate(self._dataset):
186
  self._index = index
187
  if not self._eval_dataset and self._dataset_loc > index:
 
197
  'epoch': self._train_epochs,
198
  }
199
  batch = {
200
+ 'input_tokens': np.array(token_buffer[:chunk_size], dtype=self.config.batch_token_dtype).reshape(
201
  self.config.batch_size, -1
202
  ),
203
+ 'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=self.config.batch_token_dtype).reshape(
204
  self.config.batch_size, -1
205
  ),
206
  'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(
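The new field syntax in TextProcessor distinguishes special-token fields such as <|bos|> and <|eos|> (or a literal token id) from base64 fields such as {input_tokens} that carry pre-tokenized ids. A small sketch, not part of the commit, of how such a field round-trips with the default base64_token_dtype='i4'; the field name input_tokens and the ids are made up.

import base64
import numpy as np

token_ids = np.array([1, 4865, 299, 2], dtype='i4')  # dtype matches base64_token_dtype='i4'
example = {'input_tokens': base64.b64encode(token_ids.tobytes()).decode('ascii')}

# This is what the '{input_tokens}' branch above does with the example field:
decoded = np.frombuffer(base64.b64decode(example['input_tokens']), dtype='i4').tolist()
assert decoded == token_ids.tolist()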
EasyLM/jax_utils.py CHANGED
@@ -400,3 +400,4 @@ def get_weight_decay_mask(exclusions):
400
  def tree_apply(fns, tree):
401
  """ Apply a pytree of functions to the pytree. """
402
  return jax.tree_util.tree_map(lambda fn, x: fn(x), fns, tree)
 
 
400
  def tree_apply(fns, tree):
401
  """ Apply a pytree of functions to the pytree. """
402
  return jax.tree_util.tree_map(lambda fn, x: fn(x), fns, tree)
403
+
EasyLM/models/gptj/gptj_serve.py CHANGED
@@ -18,7 +18,7 @@ from transformers import GenerationConfig, FlaxLogitsProcessorList
18
  from EasyLM.checkpoint import StreamingCheckpointer
19
  from EasyLM.serving import LMServer
20
  from EasyLM.jax_utils import (
21
- JaxRNG, next_rng, match_partition_rules, tree_apply,
22
  set_random_seed, get_float_dtype_by_name, make_shard_and_gather_fns,
23
  with_sharding_constraint, FlaxTemperatureLogitsWarper
24
  )
@@ -43,12 +43,12 @@ FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
43
  load_checkpoint='',
44
  tokenizer=GPTJConfig.get_tokenizer_config(),
45
  lm_server=LMServer.get_default_config(),
 
46
  )
47
 
48
 
49
  def main(argv):
50
- if FLAGS.initialize_jax_distributed:
51
- jax.distributed.initialize()
52
  set_random_seed(FLAGS.seed)
53
 
54
  prefix_tokenizer = GPTJConfig.get_tokenizer(
 
18
  from EasyLM.checkpoint import StreamingCheckpointer
19
  from EasyLM.serving import LMServer
20
  from EasyLM.jax_utils import (
21
+ JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, tree_apply,
22
  set_random_seed, get_float_dtype_by_name, make_shard_and_gather_fns,
23
  with_sharding_constraint, FlaxTemperatureLogitsWarper
24
  )
 
43
  load_checkpoint='',
44
  tokenizer=GPTJConfig.get_tokenizer_config(),
45
  lm_server=LMServer.get_default_config(),
46
+ jax_distributed=JaxDistributedConfig.get_default_config(),
47
  )
48
 
49
 
50
  def main(argv):
51
+ JaxDistributedConfig.initialize(FLAGS.jax_distributed)
 
52
  set_random_seed(FLAGS.seed)
53
 
54
  prefix_tokenizer = GPTJConfig.get_tokenizer(
EasyLM/models/gptj/gptj_train.py CHANGED
@@ -15,7 +15,7 @@ from EasyLM.data import DatasetFactory
15
  from EasyLM.checkpoint import StreamingCheckpointer
16
  from EasyLM.optimizers import OptimizerFactory
17
  from EasyLM.jax_utils import (
18
- JaxRNG, next_rng, match_partition_rules,
19
  cross_entropy_loss_and_accuracy, global_norm, get_float_dtype_by_name,
20
  set_random_seed, average_metrics, get_weight_decay_mask,
21
  make_shard_and_gather_fns, tree_apply
@@ -25,7 +25,6 @@ from EasyLM.models.gptj.gptj_model import GPTJConfig, FlaxGPTJForCausalLMModule
25
 
26
  FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
27
  seed=42,
28
- initialize_jax_distributed=False,
29
  mesh_dim='1,-1,1',
30
  dtype='fp32',
31
  total_steps=10000,
@@ -45,13 +44,12 @@ FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
45
  gptj=GPTJConfig.get_default_config(),
46
  logger=mlxu.WandBLogger.get_default_config(),
47
  log_all_worker=False,
 
48
  )
49
 
50
 
51
  def main(argv):
52
- if FLAGS.initialize_jax_distributed:
53
- jax.distributed.initialize()
54
-
55
  variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
56
  flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
57
  logger = mlxu.WandBLogger(
 
15
  from EasyLM.checkpoint import StreamingCheckpointer
16
  from EasyLM.optimizers import OptimizerFactory
17
  from EasyLM.jax_utils import (
18
+ JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules,
19
  cross_entropy_loss_and_accuracy, global_norm, get_float_dtype_by_name,
20
  set_random_seed, average_metrics, get_weight_decay_mask,
21
  make_shard_and_gather_fns, tree_apply
 
25
 
26
  FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
27
  seed=42,
 
28
  mesh_dim='1,-1,1',
29
  dtype='fp32',
30
  total_steps=10000,
 
44
  gptj=GPTJConfig.get_default_config(),
45
  logger=mlxu.WandBLogger.get_default_config(),
46
  log_all_worker=False,
47
+ jax_distributed=JaxDistributedConfig.get_default_config(),
48
  )
49
 
50
 
51
  def main(argv):
52
+ JaxDistributedConfig.initialize(FLAGS.jax_distributed)
 
 
53
  variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
54
  flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
55
  logger = mlxu.WandBLogger(
EasyLM/models/llama/convert_easylm_to_hf.py CHANGED
@@ -77,6 +77,14 @@ LLAMA_STANDARD_CONFIGS = {
77
  'n_heads': 32,
78
  'norm_eps': 1e-6,
79
  },
 
 
 
 
 
 
 
 
80
  '3b': {
81
  'vocab_size': 64256,
82
  'dim': 3200,
@@ -132,7 +140,7 @@ def match_keywords(string, positives, negatives):
132
 
133
  def load_and_convert_checkpoint(path):
134
  _, flax_params = StreamingCheckpointer.load_trainstate_checkpoint(path)
135
- flax_params = flatten_dict(flax_params['params']['params']['params'], sep='.')
136
  torch_params = {}
137
  for key, tensor in flax_params.items():
138
  if match_keywords(key, ["kernel"], ["norm", 'ln_f']):
@@ -219,7 +227,6 @@ def write_model(loaded, model_path, model_size):
219
  write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
220
 
221
  config = LlamaConfig(
222
- vocab_size=params["vocab_size"],
223
  hidden_size=dim,
224
  intermediate_size=params["intermediate_size"],
225
  num_attention_heads=params["n_heads"],
@@ -235,12 +242,13 @@ def write_model(loaded, model_path, model_size):
235
 
236
  print("Loading the checkpoint in a Llama model.")
237
  model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float16)
238
- print("Model parameter count", model.num_parameters())
239
  # Avoid saving this as part of the config.
 
240
  del model.config._name_or_path
241
 
242
  print("Saving in the Transformers format.")
243
  model.save_pretrained(model_path)
 
244
  shutil.rmtree(tmp_model_path)
245
 
246
 
@@ -252,21 +260,21 @@ def write_tokenizer(tokenizer_path, input_tokenizer_path):
252
  "bos_token": {
253
  "content": "<s>",
254
  "lstrip": False,
255
- "normalized": False,
256
  "rstrip": False,
257
  "single_word": False
258
  },
259
  "eos_token": {
260
  "content": "</s>",
261
  "lstrip": False,
262
- "normalized": False,
263
  "rstrip": False,
264
  "single_word": False
265
  },
266
  "unk_token": {
267
  "content": "<unk>",
268
  "lstrip": False,
269
- "normalized": False,
270
  "rstrip": False,
271
  "single_word": False
272
  },
@@ -286,7 +294,7 @@ def write_tokenizer(tokenizer_path, input_tokenizer_path):
286
  "__type": "AddedToken",
287
  "content": "<s>",
288
  "lstrip": False,
289
- "normalized": False,
290
  "rstrip": False,
291
  "single_word": False
292
  },
@@ -294,7 +302,7 @@ def write_tokenizer(tokenizer_path, input_tokenizer_path):
294
  "__type": "AddedToken",
295
  "content": "</s>",
296
  "lstrip": False,
297
- "normalized": False,
298
  "rstrip": False,
299
  "single_word": False
300
  },
@@ -302,7 +310,7 @@ def write_tokenizer(tokenizer_path, input_tokenizer_path):
302
  "__type": "AddedToken",
303
  "content": "<unk>",
304
  "lstrip": False,
305
- "normalized": False,
306
  "rstrip": False,
307
  "single_word": False
308
  },
@@ -313,7 +321,7 @@ def write_tokenizer(tokenizer_path, input_tokenizer_path):
313
 
314
 
315
  def main(argv):
316
- assert FLAGS.load_checkpoint != "" and FLAGS.output_dir != "" #and FLAGS.tokenizer_path != ""
317
  assert FLAGS.model_size in LLAMA_STANDARD_CONFIGS
318
  # write_tokenizer(
319
  # tokenizer_path=FLAGS.output_dir,
 
77
  'n_heads': 32,
78
  'norm_eps': 1e-6,
79
  },
80
+ '1b': {
81
+ 'vocab_size': 64256,
82
+ 'dim': 2048,
83
+ 'intermediate_size': 5504,
84
+ 'n_layers': 22,
85
+ 'n_heads': 16,
86
+ 'norm_eps': 1e-6,
87
+ },
88
  '3b': {
89
  'vocab_size': 64256,
90
  'dim': 3200,
 
140
 
141
  def load_and_convert_checkpoint(path):
142
  _, flax_params = StreamingCheckpointer.load_trainstate_checkpoint(path)
143
+ flax_params = flatten_dict(flax_params['params'], sep='.')
144
  torch_params = {}
145
  for key, tensor in flax_params.items():
146
  if match_keywords(key, ["kernel"], ["norm", 'ln_f']):
 
227
  write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
228
 
229
  config = LlamaConfig(
 
230
  hidden_size=dim,
231
  intermediate_size=params["intermediate_size"],
232
  num_attention_heads=params["n_heads"],
 
242
 
243
  print("Loading the checkpoint in a Llama model.")
244
  model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float16)
 
245
  # Avoid saving this as part of the config.
246
+ print("Model parameter count", model.num_parameters())
247
  del model.config._name_or_path
248
 
249
  print("Saving in the Transformers format.")
250
  model.save_pretrained(model_path)
251
+ model.save_pretrained(model_path, safe_serialization=True)
252
  shutil.rmtree(tmp_model_path)
253
 
254
 
 
260
  "bos_token": {
261
  "content": "<s>",
262
  "lstrip": False,
263
+ "normalized": True,
264
  "rstrip": False,
265
  "single_word": False
266
  },
267
  "eos_token": {
268
  "content": "</s>",
269
  "lstrip": False,
270
+ "normalized": True,
271
  "rstrip": False,
272
  "single_word": False
273
  },
274
  "unk_token": {
275
  "content": "<unk>",
276
  "lstrip": False,
277
+ "normalized": True,
278
  "rstrip": False,
279
  "single_word": False
280
  },
 
294
  "__type": "AddedToken",
295
  "content": "<s>",
296
  "lstrip": False,
297
+ "normalized": True,
298
  "rstrip": False,
299
  "single_word": False
300
  },
 
302
  "__type": "AddedToken",
303
  "content": "</s>",
304
  "lstrip": False,
305
+ "normalized": True,
306
  "rstrip": False,
307
  "single_word": False
308
  },
 
310
  "__type": "AddedToken",
311
  "content": "<unk>",
312
  "lstrip": False,
313
+ "normalized": True,
314
  "rstrip": False,
315
  "single_word": False
316
  },
 
321
 
322
 
323
  def main(argv):
324
+ assert FLAGS.load_checkpoint != "" and FLAGS.output_dir != ""# and FLAGS.tokenizer_path != ""
325
  assert FLAGS.model_size in LLAMA_STANDARD_CONFIGS
326
  # write_tokenizer(
327
  # tokenizer_path=FLAGS.output_dir,
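With the added save_pretrained(..., safe_serialization=True) call, the converted output directory also contains safetensors weights alongside the .bin shards (hence the Safetensors tag on this repo). A hedged sketch, not part of the commit, of loading the converted model with Transformers; the path is a placeholder for FLAGS.output_dir.

import torch
from transformers import LlamaForCausalLM

model = LlamaForCausalLM.from_pretrained(
    "/path/to/converted_output_dir",  # placeholder
    torch_dtype=torch.float16,
)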
EasyLM/models/llama/convert_hf_to_easylm.py ADDED
@@ -0,0 +1,196 @@
+ """
+ Usage:
+ python convert_hf_to_easylm.py \
+     --checkpoint_dir /path/hf_format_dir/ \
+     --output_file /path/easylm_format.stream \
+     --model_size 7b \
+     --streaming
+ """
+ import time
+ from pathlib import Path
+ import argparse
+
+ import mlxu
+ import torch
+ import flax
+
+ from EasyLM.checkpoint import StreamingCheckpointer
+
+ LLAMA_STANDARD_CONFIGS = {
+     '1b': {
+         'dim': 2048,
+         'intermediate_size': 5504,
+         'n_layers': 22,
+         'n_heads': 16,
+         'norm_eps': 1e-6,
+     },
+     '3b': {
+         'dim': 3200,
+         'intermediate_size': 8640,
+         'n_layers': 26,
+         'n_heads': 32,
+         'norm_eps': 1e-6,
+     },
+     "7b": {
+         "dim": 4096,
+         "intermediate_size": 11008,
+         "n_layers": 32,
+         "n_heads": 32,
+         "norm_eps": 1e-6,
+     },
+     "13b": {
+         "dim": 5120,
+         "intermediate_size": 13824,
+         "n_layers": 40,
+         "n_heads": 40,
+         "norm_eps": 1e-6,
+     },
+     "30b": {
+         "dim": 6656,
+         "intermediate_size": 17920,
+         "n_layers": 60,
+         "n_heads": 52,
+         "norm_eps": 1e-6,
+     },
+     "65b": {
+         "dim": 8192,
+         "intermediate_size": 22016,
+         "n_layers": 80,
+         "n_heads": 64,
+         "norm_eps": 1e-5,
+     },
+ }
+
+
+ def inverse_permute(params, w):
+     n_layers = params["n_layers"]
+     n_heads = params["n_heads"]
+     dim = params["dim"]
+     reshaped_w = w.reshape(n_heads, 2, dim // n_heads // 2, dim)
+     transposed_w = reshaped_w.transpose(0, 2, 1, 3)
+     inverted_w = transposed_w.reshape(dim, dim)
+     return inverted_w
+
+
+ def main(args):
+     start = time.time()
+     params = LLAMA_STANDARD_CONFIGS[args.model_size]
+
+     ckpt_paths = sorted(Path(args.checkpoint_dir).glob("*.bin"))
+     ckpt = {}
+     for i, ckpt_path in enumerate(ckpt_paths):
+         checkpoint = torch.load(ckpt_path, map_location="cpu")
+         for k, v in checkpoint.items():
+             if k.startswith("model."):
+                 k = k[6:]
+             ckpt[k] = v
+     print(f"Start convert weight to easylm format...")
+     jax_weights = {
+         "transformer": {
+             "wte": {"embedding": ckpt["embed_tokens.weight"].numpy()},
+             "ln_f": {"kernel": ckpt["norm.weight"].numpy()},
+             "h": {
+                 "%d"
+                 % (layer): {
+                     "attention": {
+                         "wq": {
+                             "kernel": inverse_permute(
+                                 params,
+                                 ckpt[f"layers.{layer}.self_attn.q_proj.weight"].numpy(),
+                             ).transpose()
+                         },
+                         "wk": {
+                             "kernel": inverse_permute(
+                                 params,
+                                 ckpt[f"layers.{layer}.self_attn.k_proj.weight"].numpy(),
+                             ).transpose()
+                         },
+                         "wv": {
+                             "kernel": ckpt[f"layers.{layer}.self_attn.v_proj.weight"]
+                             .numpy()
+                             .transpose()
+                         },
+                         "wo": {
+                             "kernel": ckpt[f"layers.{layer}.self_attn.o_proj.weight"]
+                             .numpy()
+                             .transpose()
+                         },
+                     },
+                     "feed_forward": {
+                         "w1": {
+                             "kernel": ckpt[f"layers.{layer}.mlp.gate_proj.weight"]
+                             .numpy()
+                             .transpose()
+                         },
+                         "w2": {
+                             "kernel": ckpt[f"layers.{layer}.mlp.down_proj.weight"]
+                             .numpy()
+                             .transpose()
+                         },
+                         "w3": {
+                             "kernel": ckpt[f"layers.{layer}.mlp.up_proj.weight"]
+                             .numpy()
+                             .transpose()
+                         },
+                     },
+                     "attention_norm": {
+                         "kernel": ckpt[f"layers.{layer}.input_layernorm.weight"].numpy()
+                     },
+                     "ffn_norm": {
+                         "kernel": ckpt[
+                             f"layers.{layer}.post_attention_layernorm.weight"
+                         ].numpy()
+                     },
+                 }
+                 for layer in range(params["n_layers"])
+             },
+         },
+         "lm_head": {"kernel": ckpt["lm_head.weight"].numpy().transpose()},
+     }
+     print(f"Convert weight to easylm format finished...")
+     print(f"Start to save...")
+
+     if args.streaming:
+         StreamingCheckpointer.save_train_state_to_file(jax_weights, args.output_file)
+     else:
+         with mlxu.open_file(args.output_file, "wb") as fout:
+             fout.write(flax.serialization.msgpack_serialize(jax_weights, in_place=True))
+
+     print(
+         f"Save finished!!! take time: {time.time() - start} save path: {args.output_file}"
+     )
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="hf to easylm format script")
+
+     parser.add_argument(
+         "--checkpoint_dir",
+         type=str,
+         help="Need to be converted model weight dir. it is a dir",
+     )
+     parser.add_argument(
+         "--output_file", type=str, help="Save model weight file path, it is a file."
+     )
+     parser.add_argument(
+         "--model_size",
+         type=str,
+         default="7b",
+         choices=["7b", "13b", "30b", "65b"],
+         help="model size",
+     )
+     parser.add_argument(
+         "--streaming",
+         action="store_true",
+         default=True,
+         help="whether is model weight saved stream format",
+     )
+
+     args = parser.parse_args()
+
+     print(f"checkpoint_dir: {args.checkpoint_dir}")
+     print(f"output_file: {args.output_file}")
+     print(f"model_size: {args.model_size}")
+     print(f"streaming: {args.streaming}")
+
+     main(args)
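inverse_permute above undoes the head-splitting permutation that the Hugging Face LLaMA export applies to the q/k projection matrices, so the weights land back in EasyLM's layout. A small NumPy round-trip check, not part of the commit; permute() below is restated here only for illustration (the n_layers lookup is omitted) and the sizes are arbitrary.

import numpy as np

params = {'n_heads': 4, 'dim': 32}


def permute(params, w):
    # forward permutation as applied when exporting to the HF layout
    n_heads, dim = params['n_heads'], params['dim']
    return (
        w.reshape(n_heads, dim // n_heads // 2, 2, dim)
        .transpose(0, 2, 1, 3)
        .reshape(dim, dim)
    )


def inverse_permute(params, w):
    # matches the function in convert_hf_to_easylm.py above
    n_heads, dim = params['n_heads'], params['dim']
    reshaped_w = w.reshape(n_heads, 2, dim // n_heads // 2, dim)
    return reshaped_w.transpose(0, 2, 1, 3).reshape(dim, dim)


w = np.arange(32 * 32, dtype=np.float32).reshape(32, 32)
assert np.array_equal(inverse_permute(params, permute(params, w)), w)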
EasyLM/models/llama/llama_model.py CHANGED
@@ -3,6 +3,7 @@ from shutil import copyfile
3
  from typing import Any, Dict, List, Optional, Tuple, Union
4
  import json
5
  import tempfile
 
6
 
7
  import numpy as np
8
  import jax
@@ -15,8 +16,10 @@ from flax.linen import combine_masks, make_causal_mask
15
  from flax.linen.attention import dot_product_attention_weights
16
  from flax.traverse_util import flatten_dict, unflatten_dict
17
  from flax.linen import partitioning as nn_partitioning
 
18
 
19
  import sentencepiece as spm
 
20
  from transformers.configuration_utils import PretrainedConfig
21
  from transformers.utils import logging
22
  from transformers.tokenization_utils import PreTrainedTokenizer
@@ -28,6 +31,7 @@ from ml_collections import ConfigDict
28
  from ml_collections.config_dict import config_dict
29
  from mlxu import function_args_to_config, load_pickle, open_file
30
 
 
31
  from EasyLM.jax_utils import (
32
  with_sharding_constraint, get_jax_mesh, get_gradient_checkpoint_policy
33
  )
@@ -82,6 +86,18 @@ LLAMA_STANDARD_CONFIGS = {
82
  'use_cache': True,
83
  'tie_word_embeddings': False,
84
  },
 
 
 
 
 
 
 
 
 
 
 
 
85
  '3b': {
86
  'vocab_size': 64256,
87
  'hidden_size': 3200,
@@ -219,7 +235,14 @@ class LLaMAConfig(PretrainedConfig):
219
  embd_pdrop=0.0,
220
  attn_pdrop=0.0,
221
  tie_word_embeddings=False,
222
- gradient_checkpointing='nothing_saveable',
 
 
 
 
 
 
 
223
  fcm_min_ratio=0.0,
224
  fcm_max_ratio=0.0,
225
  **kwargs,
@@ -236,7 +259,14 @@ class LLaMAConfig(PretrainedConfig):
236
  self.resid_pdrop = resid_pdrop
237
  self.embd_pdrop = embd_pdrop
238
  self.attn_pdrop = attn_pdrop
239
- self.gradient_checkpointing = gradient_checkpointing
 
 
 
 
 
 
 
240
  self.fcm_min_ratio = fcm_min_ratio
241
  self.fcm_max_ratio = fcm_max_ratio
242
  super().__init__(
@@ -302,6 +332,7 @@ class LLaMAConfig(PretrainedConfig):
302
  def get_tokenizer_config(updates=None):
303
  config = ConfigDict()
304
  config.vocab_file = ''
 
305
  config.add_bos_token = False
306
  config.add_eos_token = False
307
 
@@ -312,14 +343,23 @@ class LLaMAConfig(PretrainedConfig):
312
  @classmethod
313
  def get_tokenizer(cls, config, padding_side='left', truncation_side='right'):
314
  config = cls.get_tokenizer_config(config)
315
- assert config.vocab_file != '', 'vocab_file must be specified'
316
- tokenizer = LLaMATokenizer(
317
- vocab_file=config.vocab_file,
318
- add_bos_token=config.add_bos_token,
319
- add_eos_token=config.add_eos_token,
320
- padding_side=padding_side,
321
- truncation_side=truncation_side,
322
- )
 
 
 
 
 
 
 
 
 
323
  return tokenizer
324
 
325
  @classmethod
@@ -515,53 +555,82 @@ class FlaxLLaMAAttention(nn.Module):
515
 
516
  xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis, dtype=self.dtype)
517
 
518
- query_length, key_length = xq.shape[1], xk.shape[1]
 
 
519
 
520
- if self.has_variable("cache", "cached_key"):
521
- mask_shift = self.variables["cache"]["cache_index"]
522
- max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
523
- causal_mask = lax.dynamic_slice(
524
- self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
525
  )
 
526
  else:
527
- causal_mask = self.causal_mask[:, :, :query_length, :key_length]
528
-
529
- batch_size = hidden_states.shape[0]
530
- causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
531
-
532
- attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
533
- attention_mask = combine_masks(attention_mask, causal_mask, fcm_mask)
 
 
 
534
 
535
- dropout_rng = None
536
- if not deterministic and self.config.attn_pdrop > 0.0:
537
- dropout_rng = self.make_rng("dropout")
538
 
539
- # During fast autoregressive decoding, we feed one position at a time,
540
- # and cache the keys and values step by step.
541
- if self.has_variable("cache", "cached_key") or init_cache:
542
- xk, xv, attention_mask = self._concatenate_to_cache(xk, xv, xq, attention_mask)
543
 
544
- # transform boolean mask into float mask
545
- attention_bias = lax.select(
546
- attention_mask > 0,
547
- jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
548
- jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
549
- )
550
 
551
- # usual dot product attention
552
- attn_weights = dot_product_attention_weights(
553
- xq,
554
- xk,
555
- bias=attention_bias,
556
- dropout_rng=dropout_rng,
557
- dropout_rate=self.config.attn_pdrop,
558
- deterministic=deterministic,
559
- dtype=jnp.promote_types(self.dtype, jnp.float32),
560
- precision=self.precision,
561
- )
562
- attn_weights = with_sharding_constraint(attn_weights, PS(("dp", "fsdp"), "mp", None, None))
 
 
 
 
 
 
563
 
564
- attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, xv, precision=self.precision)
565
  attn_output = self._merge_heads(attn_output)
566
  attn_output = self.wo(attn_output)
567
  attn_output = self.resid_dropout(attn_output, deterministic=deterministic)
@@ -617,13 +686,28 @@ class FlaxLLaMABlock(nn.Module):
617
  precision: Optional[Union[jax.lax.Precision, str]]=None
618
 
619
  def setup(self) -> None:
620
- self.attention = FlaxLLaMAAttention(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
621
  self.config,
622
  dtype=self.dtype,
623
  param_dtype=self.param_dtype,
624
  precision=self.precision,
625
  )
626
- self.feed_forward = FlaxLLaMAMLP(
627
  self.config,
628
  dtype=self.dtype,
629
  param_dtype=self.param_dtype,
@@ -654,20 +738,32 @@ class FlaxLLaMABlock(nn.Module):
654
  ):
655
  attn_outputs = self.attention(
656
  self.attention_norm(hidden_states),
657
- attention_mask=attention_mask,
658
- position_ids=position_ids,
659
- deterministic=deterministic,
660
- init_cache=init_cache,
661
- output_attentions=output_attentions,
662
- fcm_mask=fcm_mask,
663
  )
664
  attn_output = attn_outputs[0]
665
  hidden_states = hidden_states + attn_output
666
 
667
- feed_forward_hidden_states = self.feed_forward(
668
- self.ffn_norm(hidden_states),
669
- deterministic=deterministic,
670
- )
 
 
 
 
 
 
 
 
 
 
 
 
671
  hidden_states = hidden_states + feed_forward_hidden_states
672
 
673
  return (hidden_states,) + attn_outputs[1:]
@@ -828,14 +924,19 @@ class FlaxLLaMABlockCollection(nn.Module):
828
 
829
  def setup(self):
830
  block = FlaxLLaMABlock
831
- if self.config.gradient_checkpointing != '':
832
- FlaxLLaMACheckpointBlock = remat(
833
- block, static_argnums=(3, 4, 5),
834
- policy=get_gradient_checkpoint_policy(self.config.gradient_checkpointing)
835
  )
836
- block = FlaxLLaMACheckpointBlock
837
  self.blocks = [
838
- block(self.config, name=str(i), dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision) for i in range(self.config.num_hidden_layers)
 
 
 
 
 
 
839
  ]
840
 
841
  def __call__(
@@ -862,7 +963,7 @@ class FlaxLLaMABlockCollection(nn.Module):
862
  )
863
  fcm_mask = jax.random.uniform(
864
  self.make_rng('fcm'),
865
- shape=(batch_size, 1, seq_length, seq_length)
866
  ) > fcm_ratio
867
  fcm_mask = fcm_mask.at[:, :, :, 0].set(True)
868
  fcm_mask = fcm_mask.astype('bool')
@@ -1034,7 +1135,7 @@ class FlaxLLaMAForCausalLMModule(nn.Module):
1034
  class FlaxLLaMAForCausalLM(FlaxLLaMAPreTrainedModel):
1035
  module_class = FlaxLLaMAForCausalLMModule
1036
 
1037
- def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
1038
  # initializing the cache
1039
  batch_size, seq_length = input_ids.shape
1040
 
 
3
  from typing import Any, Dict, List, Optional, Tuple, Union
4
  import json
5
  import tempfile
6
+ from functools import partial
7
 
8
  import numpy as np
9
  import jax
 
16
  from flax.linen.attention import dot_product_attention_weights
17
  from flax.traverse_util import flatten_dict, unflatten_dict
18
  from flax.linen import partitioning as nn_partitioning
19
+ import einops
20
 
21
  import sentencepiece as spm
22
+ from transformers import AutoTokenizer
23
  from transformers.configuration_utils import PretrainedConfig
24
  from transformers.utils import logging
25
  from transformers.tokenization_utils import PreTrainedTokenizer
 
31
  from ml_collections.config_dict import config_dict
32
  from mlxu import function_args_to_config, load_pickle, open_file
33
 
34
+ from EasyLM.bpt import blockwise_ffn, blockwise_attn
35
  from EasyLM.jax_utils import (
36
  with_sharding_constraint, get_jax_mesh, get_gradient_checkpoint_policy
37
  )
 
86
  'use_cache': True,
87
  'tie_word_embeddings': False,
88
  },
89
+ '1b': {
90
+ 'vocab_size': 64256,
91
+ 'hidden_size': 2048,
92
+ 'intermediate_size': 5504,
93
+ 'num_hidden_layers': 22,
94
+ 'num_attention_heads': 16,
95
+ 'max_sequence_length': 2048,
96
+ 'initializer_range': 0.02,
97
+ 'rms_norm_eps': 1e-6,
98
+ 'use_cache': True,
99
+ 'tie_word_embeddings': False,
100
+ },
101
  '3b': {
102
  'vocab_size': 64256,
103
  'hidden_size': 3200,
 
235
  embd_pdrop=0.0,
236
  attn_pdrop=0.0,
237
  tie_word_embeddings=False,
238
+ remat_block='',
239
+ remat_attention='',
240
+ remat_mlp='',
241
+ scan_attention=False,
242
+ scan_mlp=False,
243
+ scan_query_chunk_size=1024,
244
+ scan_key_chunk_size=1024,
245
+ scan_mlp_chunk_size=1024,
246
  fcm_min_ratio=0.0,
247
  fcm_max_ratio=0.0,
248
  **kwargs,
 
259
  self.resid_pdrop = resid_pdrop
260
  self.embd_pdrop = embd_pdrop
261
  self.attn_pdrop = attn_pdrop
262
+ self.remat_block = remat_block
263
+ self.remat_attention = remat_attention
264
+ self.remat_mlp = remat_mlp
265
+ self.scan_attention = scan_attention
266
+ self.scan_mlp = scan_mlp
267
+ self.scan_query_chunk_size = scan_query_chunk_size
268
+ self.scan_key_chunk_size = scan_key_chunk_size
269
+ self.scan_mlp_chunk_size = scan_mlp_chunk_size
270
  self.fcm_min_ratio = fcm_min_ratio
271
  self.fcm_max_ratio = fcm_max_ratio
272
  super().__init__(
 
332
  def get_tokenizer_config(updates=None):
333
  config = ConfigDict()
334
  config.vocab_file = ''
335
+ config.pretrained_model_name_or_path = ''
336
  config.add_bos_token = False
337
  config.add_eos_token = False
338
 
 
343
  @classmethod
344
  def get_tokenizer(cls, config, padding_side='left', truncation_side='right'):
345
  config = cls.get_tokenizer_config(config)
346
+ assert config.vocab_file != '' and config.pretrained_model_name_or_path != '', 'vocab_file or pretrained_model_name_or_path must be specified'
347
+ if config.pretrained_model_name_or_path != '':
348
+ tokenizer = AutoTokenizer.from_pretrained(
349
+ config.pretrained_model_name_or_path,
350
+ add_bos_token=config.add_bos_token,
351
+ add_eos_token=config.add_eos_token,
352
+ padding_side=padding_side,
353
+ truncation_side=truncation_side,
354
+ )
355
+ else:
356
+ tokenizer = LLaMATokenizer(
357
+ vocab_file=config.vocab_file,
358
+ add_bos_token=config.add_bos_token,
359
+ add_eos_token=config.add_eos_token,
360
+ padding_side=padding_side,
361
+ truncation_side=truncation_side,
362
+ )
363
  return tokenizer
364
 
365
  @classmethod
 
555
 
556
  xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis, dtype=self.dtype)
557
 
558
+ dropout_rng = None
559
+ if not deterministic and self.config.attn_pdrop > 0.0:
560
+ dropout_rng = self.make_rng("dropout")
561
 
562
+ if self.config.scan_attention and not (self.has_variable("cache", "cached_key") or init_cache):
563
+ # doesn't need blockwise attention if we are doing autoregressive decoding since no quadratic memory
564
+
565
+ # attention mask without nxn materlization, blockwise_attn will handle the rest
566
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
567
+ # transform boolean mask into float mask
568
+ attention_bias = lax.select(
569
+ attention_mask > 0,
570
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
571
+ jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
572
+ )
573
+ attn_weights = None
574
+ attn_output = blockwise_attn(
575
+ xq,
576
+ xk,
577
+ xv,
578
+ bias=attention_bias,
579
+ deterministic=deterministic,
580
+ dropout_rng=dropout_rng,
581
+ attn_pdrop=self.config.attn_pdrop,
582
+ causal=True,
583
+ query_chunk_size=self.config.scan_query_chunk_size,
584
+ key_chunk_size=self.config.scan_key_chunk_size,
585
+ dtype=self.dtype,
586
+ policy=get_gradient_checkpoint_policy('nothing_saveable'),
587
+ precision=self.precision,
588
+ float32_logits=True,
589
+ prevent_cse=True,
590
  )
591
+ attn_output = with_sharding_constraint(attn_output, PS(("dp", "fsdp"), None, "mp", None))
592
  else:
593
+ query_length, key_length = xq.shape[1], xk.shape[1]
594
+
595
+ if self.has_variable("cache", "cached_key"):
596
+ mask_shift = self.variables["cache"]["cache_index"]
597
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
598
+ causal_mask = lax.dynamic_slice(
599
+ self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
600
+ )
601
+ else:
602
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
603
 
604
+ batch_size = hidden_states.shape[0]
605
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
 
606
 
607
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
608
+ attention_mask = combine_masks(attention_mask, causal_mask, fcm_mask)
 
 
609
 
610
+ # During fast autoregressive decoding, we feed one position at a time,
611
+ # and cache the keys and values step by step.
612
+ if self.has_variable("cache", "cached_key") or init_cache:
613
+ xk, xv, attention_mask = self._concatenate_to_cache(xk, xv, xq, attention_mask)
 
 
614
 
615
+ # transform boolean mask into float mask
616
+ attention_bias = lax.select(
617
+ attention_mask > 0,
618
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
619
+ jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
620
+ )
621
+ attn_weights = dot_product_attention_weights(
622
+ xq,
623
+ xk,
624
+ bias=attention_bias,
625
+ dropout_rng=dropout_rng,
626
+ dropout_rate=self.config.attn_pdrop,
627
+ deterministic=deterministic,
628
+ dtype=jnp.promote_types(self.dtype, jnp.float32),
629
+ precision=self.precision,
630
+ )
631
+ attn_weights = with_sharding_constraint(attn_weights, PS(("dp", "fsdp"), "mp", None, None))
632
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, xv, precision=self.precision)
633
 
 
634
  attn_output = self._merge_heads(attn_output)
635
  attn_output = self.wo(attn_output)
636
  attn_output = self.resid_dropout(attn_output, deterministic=deterministic)
 
686
  precision: Optional[Union[jax.lax.Precision, str]]=None
687
 
688
  def setup(self) -> None:
689
+ attention_module = FlaxLLaMAAttention
690
+ mlp_module = FlaxLLaMAMLP
691
+ if self.config.remat_attention != '':
692
+ attention_module = remat(
693
+ FlaxLLaMAAttention, static_argnums=(3, 4, 5),
694
+ policy=get_gradient_checkpoint_policy(self.config.remat_attention),
695
+ prevent_cse=True,
696
+ )
697
+ if self.config.remat_mlp != '':
698
+ mlp_module = remat(
699
+ FlaxLLaMAMLP, static_argnums=(1,),
700
+ policy=get_gradient_checkpoint_policy(self.config.remat_mlp),
701
+ prevent_cse=True,
702
+ )
703
+
704
+ self.attention = attention_module(
705
  self.config,
706
  dtype=self.dtype,
707
  param_dtype=self.param_dtype,
708
  precision=self.precision,
709
  )
710
+ self.feed_forward = mlp_module(
711
  self.config,
712
  dtype=self.dtype,
713
  param_dtype=self.param_dtype,
 
738
  ):
739
  attn_outputs = self.attention(
740
  self.attention_norm(hidden_states),
741
+ attention_mask,
742
+ position_ids,
743
+ deterministic,
744
+ init_cache,
745
+ output_attentions,
746
+ fcm_mask,
747
  )
748
  attn_output = attn_outputs[0]
749
  hidden_states = hidden_states + attn_output
750
 
751
+ feed_forward_input = self.ffn_norm(hidden_states)
752
+
753
+ if self.config.scan_mlp:
754
+ feed_forward_hidden_states = blockwise_ffn(
755
+ self.feed_forward,
756
+ feed_forward_input,
757
+ self.config.scan_mlp_chunk_size,
758
+ deterministic,
759
+ )
760
+ else:
761
+ feed_forward_hidden_states = self.feed_forward(
762
+ feed_forward_input,
763
+ deterministic,
764
+ )
765
+ feed_forward_hidden_states = with_sharding_constraint(feed_forward_hidden_states, PS(("dp", "fsdp"), None, "mp"))
766
+
767
  hidden_states = hidden_states + feed_forward_hidden_states
768
 
769
  return (hidden_states,) + attn_outputs[1:]
 
924
 
925
  def setup(self):
926
  block = FlaxLLaMABlock
927
+ if self.config.remat_block != '':
928
+ block = remat(
929
+ FlaxLLaMABlock, static_argnums=(3, 4, 5),
930
+ policy=get_gradient_checkpoint_policy(self.config.remat_block)
931
  )
 
932
  self.blocks = [
933
+ block(
934
+ self.config,
935
+ name=str(i),
936
+ dtype=self.dtype,
937
+ param_dtype=self.param_dtype,
938
+ precision=self.precision
939
+ ) for i in range(self.config.num_hidden_layers)
940
  ]
941
 
942
  def __call__(
 
963
  )
964
  fcm_mask = jax.random.uniform(
965
  self.make_rng('fcm'),
966
+ shape=(batch_size, 1, 1, seq_length)
967
  ) > fcm_ratio
968
  fcm_mask = fcm_mask.at[:, :, :, 0].set(True)
969
  fcm_mask = fcm_mask.astype('bool')
 
1135
  class FlaxLLaMAForCausalLM(FlaxLLaMAPreTrainedModel):
1136
  module_class = FlaxLLaMAForCausalLMModule
1137
 
1138
+ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
1139
  # initializing the cache
1140
  batch_size, seq_length = input_ids.shape
1141
 
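The new config fields give per-module control over rematerialization (remat_block / remat_attention / remat_mlp) and enable the blockwise attention and FFN paths (scan_attention / scan_mlp plus their chunk sizes). A minimal sketch, not part of the commit, of constructing a config with these options; the hyperparameters are illustrative and the chunk sizes must divide the training sequence length.

from EasyLM.models.llama.llama_model import LLaMAConfig, FlaxLLaMAForCausalLMModule

llama_config = LLaMAConfig(
    vocab_size=64256,
    hidden_size=2048,
    intermediate_size=5504,
    num_hidden_layers=22,
    num_attention_heads=16,
    max_sequence_length=2048,
    # options added in this commit
    remat_block='nothing_saveable',   # checkpoint whole transformer blocks
    scan_attention=True,              # use blockwise_attn outside of decoding
    scan_mlp=True,                    # use blockwise_ffn for the MLP
    scan_query_chunk_size=1024,
    scan_key_chunk_size=1024,
    scan_mlp_chunk_size=1024,
)
model = FlaxLLaMAForCausalLMModule(llama_config)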
EasyLM/models/llama/llama_serve.py CHANGED
@@ -14,7 +14,7 @@ from transformers import GenerationConfig, FlaxLogitsProcessorList
14
  from EasyLM.checkpoint import StreamingCheckpointer
15
  from EasyLM.serving import LMServer
16
  from EasyLM.jax_utils import (
17
- JaxRNG, next_rng, match_partition_rules, tree_apply,
18
  set_random_seed, get_float_dtype_by_name, make_shard_and_gather_fns,
19
  with_sharding_constraint, FlaxTemperatureLogitsWarper
20
  )
@@ -37,12 +37,12 @@ FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
37
  load_checkpoint='',
38
  tokenizer=LLaMAConfig.get_tokenizer_config(),
39
  lm_server=LMServer.get_default_config(),
 
40
  )
41
 
42
 
43
  def main(argv):
44
- if FLAGS.initialize_jax_distributed:
45
- jax.distributed.initialize()
46
  set_random_seed(FLAGS.seed)
47
 
48
  prefix_tokenizer = LLaMAConfig.get_tokenizer(
 
14
  from EasyLM.checkpoint import StreamingCheckpointer
15
  from EasyLM.serving import LMServer
16
  from EasyLM.jax_utils import (
17
+ JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, tree_apply,
18
  set_random_seed, get_float_dtype_by_name, make_shard_and_gather_fns,
19
  with_sharding_constraint, FlaxTemperatureLogitsWarper
20
  )
 
37
  load_checkpoint='',
38
  tokenizer=LLaMAConfig.get_tokenizer_config(),
39
  lm_server=LMServer.get_default_config(),
40
+ jax_distributed=JaxDistributedConfig.get_default_config(),
41
  )
42
 
43
 
44
  def main(argv):
45
+ JaxDistributedConfig.initialize(FLAGS.jax_distributed)
 
46
  set_random_seed(FLAGS.seed)
47
 
48
  prefix_tokenizer = LLaMAConfig.get_tokenizer(
EasyLM/models/llama/llama_train.py CHANGED
@@ -15,7 +15,7 @@ from EasyLM.data import DatasetFactory
15
  from EasyLM.checkpoint import StreamingCheckpointer
16
  from EasyLM.optimizers import OptimizerFactory
17
  from EasyLM.jax_utils import (
18
- JaxRNG, next_rng, match_partition_rules,
19
  cross_entropy_loss_and_accuracy, global_norm, get_float_dtype_by_name,
20
  set_random_seed, average_metrics, get_weight_decay_mask,
21
  make_shard_and_gather_fns, with_sharding_constraint,
@@ -27,10 +27,8 @@ from EasyLM.models.llama.llama_model import (
27
 
28
  FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
29
  seed=42,
30
- initialize_jax_distributed=False,
31
  mesh_dim='1,-1,1',
32
  dtype='fp32',
33
- param_dtype='fp32',
34
  total_steps=10000,
35
  load_llama_config='',
36
  update_llama_config='',
@@ -48,13 +46,12 @@ FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
48
  llama=LLaMAConfig.get_default_config(),
49
  logger=mlxu.WandBLogger.get_default_config(),
50
  log_all_worker=False,
 
51
  )
52
 
53
 
54
  def main(argv):
55
- if FLAGS.initialize_jax_distributed:
56
- jax.distributed.initialize()
57
-
58
  variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
59
  flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
60
  logger = mlxu.WandBLogger(
@@ -66,7 +63,6 @@ def main(argv):
66
 
67
  tokenizer = LLaMAConfig.get_tokenizer(FLAGS.tokenizer)
68
  dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer)
69
-
70
  if FLAGS.load_dataset_state != '':
71
  dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state))
72
 
@@ -90,6 +86,7 @@ def main(argv):
90
  eos_token_id=dataset.tokenizer.eos_token_id,
91
  ))
92
  if llama_config.vocab_size < dataset.vocab_size:
 
93
  llama_config.update(dict(vocab_size=dataset.vocab_size))
94
 
95
  model = FlaxLLaMAForCausalLMModule(
@@ -250,7 +247,7 @@ def main(argv):
250
  metrics.update(average_metrics(eval_metric_list))
251
 
252
  if FLAGS.log_freq > 0 and (step + 1) % FLAGS.log_freq == 0:
253
- log_metrics = {"step": step}
254
  log_metrics.update(metrics)
255
  log_metrics.update(dataset_metrics)
256
  log_metrics = jax.device_get(log_metrics)
 
15
  from EasyLM.checkpoint import StreamingCheckpointer
16
  from EasyLM.optimizers import OptimizerFactory
17
  from EasyLM.jax_utils import (
18
+ JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules,
19
  cross_entropy_loss_and_accuracy, global_norm, get_float_dtype_by_name,
20
  set_random_seed, average_metrics, get_weight_decay_mask,
21
  make_shard_and_gather_fns, with_sharding_constraint,
 
27
 
28
  FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
29
  seed=42,
 
30
  mesh_dim='1,-1,1',
31
  dtype='fp32',
 
32
  total_steps=10000,
33
  load_llama_config='',
34
  update_llama_config='',
 
46
  llama=LLaMAConfig.get_default_config(),
47
  logger=mlxu.WandBLogger.get_default_config(),
48
  log_all_worker=False,
49
+ jax_distributed=JaxDistributedConfig.get_default_config(),
50
  )
51
 
52
 
53
  def main(argv):
54
+ JaxDistributedConfig.initialize(FLAGS.jax_distributed)
 
 
55
  variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
56
  flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
57
  logger = mlxu.WandBLogger(
 
63
 
64
  tokenizer = LLaMAConfig.get_tokenizer(FLAGS.tokenizer)
65
  dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer)
 
66
  if FLAGS.load_dataset_state != '':
67
  dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state))
68
 
 
86
  eos_token_id=dataset.tokenizer.eos_token_id,
87
  ))
88
  if llama_config.vocab_size < dataset.vocab_size:
89
+ print("Updating model config vocab size from", llama_config.vocab_size, "to", dataset.vocab_size)
90
  llama_config.update(dict(vocab_size=dataset.vocab_size))
91
 
92
  model = FlaxLLaMAForCausalLMModule(
 
247
  metrics.update(average_metrics(eval_metric_list))
248
 
249
  if FLAGS.log_freq > 0 and (step + 1) % FLAGS.log_freq == 0:
250
+ log_metrics = {"step": step + 1}
251
  log_metrics.update(metrics)
252
  log_metrics.update(dataset_metrics)
253
  log_metrics = jax.device_get(log_metrics)
EasyLM/models/roberta/roberta_train.py CHANGED
@@ -17,7 +17,7 @@ from EasyLM.data import DatasetFactory
17
  from EasyLM.checkpoint import StreamingCheckpointer
18
  from EasyLM.optimizers import OptimizerFactory
19
  from EasyLM.jax_utils import (
20
- JaxRNG, next_rng, match_partition_rules, get_float_dtype_by_name,
21
  cross_entropy_loss_and_accuracy, named_tree_map, global_norm,
22
  set_random_seed, average_metrics, get_weight_decay_mask,
23
  make_shard_and_gather_fns, tree_apply
@@ -29,7 +29,6 @@ from EasyLM.models.roberta.roberta_model import (
29
 
30
  FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
31
  seed=42,
32
- initialize_jax_distributed=False,
33
  mesh_dim='-1,1,1',
34
  dtype='fp32',
35
  mask_token_probability=0.15,
@@ -50,13 +49,12 @@ FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
50
  roberta=RobertaConfig.get_default_config(),
51
  logger=mlxu.WandBLogger.get_default_config(),
52
  log_all_worker=False,
 
53
  )
54
 
55
 
56
  def main(argv):
57
- if FLAGS.initialize_jax_distributed:
58
- jax.distributed.initialize()
59
-
60
  variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
61
  flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
62
  logger = mlxu.WandBLogger(
 
17
  from EasyLM.checkpoint import StreamingCheckpointer
18
  from EasyLM.optimizers import OptimizerFactory
19
  from EasyLM.jax_utils import (
20
+ JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, get_float_dtype_by_name,
21
  cross_entropy_loss_and_accuracy, named_tree_map, global_norm,
22
  set_random_seed, average_metrics, get_weight_decay_mask,
23
  make_shard_and_gather_fns, tree_apply
 
29
 
30
  FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
31
  seed=42,
 
32
  mesh_dim='-1,1,1',
33
  dtype='fp32',
34
  mask_token_probability=0.15,
 
49
  roberta=RobertaConfig.get_default_config(),
50
  logger=mlxu.WandBLogger.get_default_config(),
51
  log_all_worker=False,
52
+ jax_distributed=JaxDistributedConfig.get_default_config(),
53
  )
54
 
55
 
56
  def main(argv):
57
+ JaxDistributedConfig.initialize(FLAGS.jax_distributed)
 
 
58
  variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
59
  flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
60
  logger = mlxu.WandBLogger(
EasyLM/optimizers.py CHANGED
@@ -193,7 +193,6 @@ class AdamWOptimizerFactory(object):
193
 
194
  return optimizer, optimizer_info
195
 
196
-
197
  class LionOptimizerFactory(object):
198
  """ Lion optimizer with cosine schedule. """
199
 
@@ -250,7 +249,7 @@ class LionOptimizerFactory(object):
250
 
251
 
252
  class OptaxScheduledWeightDecayState(NamedTuple):
253
- count: jnp.DeviceArray
254
 
255
 
256
  def optax_add_scheduled_weight_decay(schedule_fn, mask=None):
 
193
 
194
  return optimizer, optimizer_info
195
 
 
196
  class LionOptimizerFactory(object):
197
  """ Lion optimizer with cosine schedule. """
198
 
 
249
 
250
 
251
  class OptaxScheduledWeightDecayState(NamedTuple):
252
+ count: jax.Array
253
 
254
 
255
  def optax_add_scheduled_weight_decay(schedule_fn, mask=None):
EasyLM/scripts/benchmark_attention.py ADDED
@@ -0,0 +1,150 @@
+ from functools import partial
+ from time import time
+ import os
+ import numpy as np
+ import jax
+ import jax.flatten_util
+ import jax.numpy as jnp
+ import mlxu
+ from EasyLM.bpt import blockwise_attn
+ from EasyLM.jax_utils import (
+     get_float_dtype_by_name, set_random_seed, next_rng, JaxRNG
+ )
+
+
+ FLAGS, _ = mlxu.define_flags_with_default(
+     seed=42,
+     dtype='fp32',
+     embed_dim=2048,
+     n_heads=16,
+     ref_attn_seq_len=2048,
+     eff_attn_seq_len=16384,
+     batch_size=1,
+     query_chunk_size=2048,
+     key_chunk_size=2048,
+     warmup_steps=40,
+     steps=200,
+ )
+
+
+ def main(argv):
+
+     def random_kqv(rng_key, seq_len):
+         rng_generator = JaxRNG(rng_key)
+         kqv = []
+         for i in range(3):
+             kqv.append(
+                 jax.random.normal(
+                     rng_generator(),
+                     (FLAGS.batch_size, seq_len, FLAGS.n_heads, FLAGS.embed_dim // FLAGS.n_heads),
+                     dtype=get_float_dtype_by_name(FLAGS.dtype)
+                 )
+             )
+         return tuple(kqv)
+
+     def reference_attn(query, key, value):
+         dtype = get_float_dtype_by_name(FLAGS.dtype)
+         query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
+         logits = jnp.einsum("bqhc,bkhc->bhqk", query, key)
+         mask_value = jnp.finfo(logits.dtype).min
+         _, q_seq_len, _, _ = query.shape
+         _, kv_seq_len, _, _ = key.shape
+         mask_shape = (q_seq_len, kv_seq_len)
+         row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
+         col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
+         causal_mask = (row_ids < col_ids)[None, None, :, :]
+         logits = logits + jnp.where(causal_mask, mask_value, 0.0)
+         weights = jax.nn.softmax(logits, axis=-1)
+         out = jnp.einsum("bhqk,bkhc->bqhc", weights, value)
+         return out
+
+     def efficient_attention(query, key, value):
+         dtype = get_float_dtype_by_name(FLAGS.dtype)
+         return blockwise_attn(
+             query, key, value,
+             bias=None,
+             deterministic=True,
+             dropout_rng=None,
+             attn_pdrop=0.0,
+             causal=True,
+             query_chunk_size=FLAGS.query_chunk_size,
+             key_chunk_size=FLAGS.key_chunk_size,
+             dtype=get_float_dtype_by_name(FLAGS.dtype),
+             policy=jax.checkpoint_policies.nothing_saveable(),
+             precision=None,
+             float32_logits=True,
+             prevent_cse=True,
+         )
+
+
+     @partial(jax.jit, static_argnums=(1,))
+     def reference_attn_forward_backward(rng_key, seq_len):
+         @partial(jax.grad, argnums=(0, 1, 2))
+         @partial(jax.checkpoint, policy=jax.checkpoint_policies.nothing_saveable())
+         def grad_fn(query, key, value):
+             out = reference_attn(query, key, value)
+             return jnp.mean(out)
+
+         query, key, value = random_kqv(rng_key, seq_len)
+         return jax.flatten_util.ravel_pytree(
+             grad_fn(query, key, value)[1]
+         )[0].mean()
+
+     @partial(jax.jit, static_argnums=(1,))
+     def efficient_attn_forward_backward(rng_key, seq_len):
+         @partial(jax.grad, argnums=(0, 1, 2))
+         def grad_fn(query, key, value):
+             out = efficient_attention(query, key, value)
+             return jnp.mean(out)
+
+         query, key, value = random_kqv(rng_key, seq_len)
+         return jax.flatten_util.ravel_pytree(
+             grad_fn(query, key, value)[1]
+         )[0].mean()
+
+
+     set_random_seed(FLAGS.seed)
+
+     jax.block_until_ready(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len))
+     jax.block_until_ready(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len))
+
+     all_results = []
+     for i in range(FLAGS.warmup_steps):
+         all_results.append(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len))
+     jax.block_until_ready(all_results)
+
+     start_time = time()
+     all_results = []
+     for i in range(FLAGS.steps):
+         all_results.append(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len))
+
+     jax.block_until_ready(all_results)
+     elapsed_time_ref_attn = time() - start_time
+     print(f'Reference attention: {elapsed_time_ref_attn:.3f} seconds')
+
+
+     all_results = []
+     for i in range(FLAGS.warmup_steps):
+         all_results.append(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len))
+     jax.block_until_ready(all_results)
+
+
+     start_time = time()
+     all_results = []
+     for i in range(FLAGS.steps):
+         all_results.append(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len))
+
+     jax.block_until_ready(all_results)
+     elapsed_time_efficient_attn = time() - start_time
+     print(f'Efficient attention: {elapsed_time_efficient_attn:.3f} seconds')
+
+     flops_ratio = (FLAGS.eff_attn_seq_len / FLAGS.ref_attn_seq_len) ** 2
+     efficiency = elapsed_time_ref_attn / elapsed_time_efficient_attn * flops_ratio
+     print(f'Efficiency: {efficiency:.3f}')
+
+
+ if __name__ == '__main__':
+     mlxu.run(main)
+
+
+
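The efficiency number printed by the benchmark normalizes the raw timing ratio by the quadratic growth of attention cost with sequence length. A worked example with made-up timings:

ref_seq_len, eff_seq_len = 2048, 16384          # defaults of the flags above
elapsed_ref, elapsed_eff = 1.0, 40.0            # hypothetical seconds over FLAGS.steps
flops_ratio = (eff_seq_len / ref_seq_len) ** 2  # 8x longer sequence -> 64x attention work
efficiency = elapsed_ref / elapsed_eff * flops_ratio
print(efficiency)  # 1.6: blockwise attention did 64x the work in only 40x the time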
EasyLM/scripts/lm_eval_harness.py CHANGED
@@ -20,6 +20,8 @@ from EasyLM.serving import LMClient
20
  FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
21
  tasks='wsc,piqa,winogrande,openbookqa,logiqa',
22
  shots=0,
 
 
23
  lm_client=LMClient.get_default_config(),
24
  logger=mlxu.WandBLogger.get_default_config(),
25
  )
@@ -51,7 +53,9 @@ def main(argv):
51
  model = LMEvalHarnessInterface(LMClient(FLAGS.lm_client))
52
  task_list = FLAGS.tasks.split(',')
53
  results = evaluator.evaluate(
54
- model, tasks.get_task_dict(task_list), False, FLAGS.shots, None
 
 
55
  )
56
  logger.log(flatten_dict(results['results'], sep='/'))
57
  pprint.pprint(results)
 
20
  FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
21
  tasks='wsc,piqa,winogrande,openbookqa,logiqa',
22
  shots=0,
23
+ limit=0,
24
+ write_out=False,
25
  lm_client=LMClient.get_default_config(),
26
  logger=mlxu.WandBLogger.get_default_config(),
27
  )
 
53
  model = LMEvalHarnessInterface(LMClient(FLAGS.lm_client))
54
  task_list = FLAGS.tasks.split(',')
55
  results = evaluator.evaluate(
56
+ model, tasks.get_task_dict(task_list), False, FLAGS.shots,
57
+ limit=None if FLAGS.limit <= 0 else FLAGS.limit,
58
+ write_out=FLAGS.write_out,
59
  )
60
  logger.log(flatten_dict(results['results'], sep='/'))
61
  pprint.pprint(results)