aapot commited on
Commit
5a63fc6
1 Parent(s): 64db1e7

Add easylm training code

Browse files
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__
EasyLM/__init__.py ADDED
File without changes
EasyLM/bpt.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ An implementation of Blockwise parallel transformer https://arxiv.org/abs/2305.19370
3
+ Also include a reference implementation of memory-efficient transformer https://arxiv.org/abs/2112.05682
4
+ """
5
+
6
+ import functools
7
+ from typing import NamedTuple
8
+
9
+ import flax.linen as nn
10
+ import jax
11
+ import jax.lax as lax
12
+ import jax.numpy as jnp
13
+ from einops import rearrange
14
+
15
+ """
16
+ Computing ffn blockwise without materializing the large hidden tensor, training
17
+ 4x longer sequences than the memory-efficient transformer.
18
+ Blockwise parallel transformer https://arxiv.org/abs/2305.19370 Liu et al. 2023
19
+ """
20
+ def blockwise_ffn(remat_ffn, inputs, chunk_size=2048, deterministic=True):
21
+ # remat_ffn: a rematerialized ffn with policy jax.checkpoint_policies.nothing_saveable()
22
+ # inputs: (batch, seq_len, dim)
23
+ # chunk_size: the chunk size to split the sequence
24
+ inputs = rearrange(inputs, 'b (c n) d -> b c n d', c=chunk_size)
25
+ def scan_ffn(remat_ffn, carry, hidden_states):
26
+ outputs = remat_ffn(hidden_states, deterministic=deterministic)
27
+ return carry, outputs
28
+ scan_axis = inputs.ndim - 2
29
+ _, res = nn.scan(
30
+ scan_ffn,
31
+ variable_broadcast="params",
32
+ split_rngs={"params": False, "dropout": True},
33
+ in_axes=scan_axis,
34
+ out_axes=scan_axis,
35
+ )(remat_ffn, None, inputs)
36
+ res = rearrange(res, 'b c n d -> b (c n) d')
37
+ return res
38
+
39
+
40
+ """
41
+ Compute attention blockwise without materializing the full attention matrix,
42
+ initially proposed in memory-efficient transformer https://arxiv.org/abs/2112.05682 Rabe et al. 2021;
43
+ flash attention https://arxiv.org/abs/2205.14135 Dao et al. 2022 proposes a CUDA
44
+ efficient implementation; blockwise parallel transformer https://arxiv.org/abs/2305.19370
45
+ Liu et al. 2023 proposes blockwise computing both attention and FFN, enabling 4x
46
+ longer sequences than memory-efficient/flash-attention and fusion of attention and FFN.
47
+ """
48
+ def blockwise_attn(
49
+ query, key, value,
50
+ bias=None,
51
+ deterministic=True,
52
+ dropout_rng=None,
53
+ attn_pdrop=0.0,
54
+ causal=True,
55
+ query_chunk_size=2048,
56
+ key_chunk_size=2048,
57
+ dtype=jnp.float32,
58
+ policy=jax.checkpoint_policies.nothing_saveable(),
59
+ precision=None,
60
+ float32_logits=True,
61
+ prevent_cse=True,
62
+ ):
63
+ # query, key, value: (batch, seq_len, num_heads, dim_per_head)
64
+ # bias: (batch, seq_len) can be used to mask out attention (e.g. padding)
65
+ # causal: whether to use causal mask
66
+ # policy: one of jax.checkpoint_policies
67
+ query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
68
+ if float32_logits:
69
+ query = query.astype(jnp.float32)
70
+ key = key.astype(jnp.float32)
71
+
72
+ batch, q_len, num_heads, dim_per_head = query.shape
73
+ batch, kv_len, num_heads, dim_per_head = key.shape
74
+ batch, kv_len, num_heads, dim_per_head = value.shape
75
+
76
+ num_q = q_len // query_chunk_size
77
+ num_kv = kv_len // key_chunk_size
78
+ query = query.reshape((batch, num_q, query_chunk_size, num_heads, dim_per_head))
79
+ key = key.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
80
+ value = value.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
81
+
82
+ query = jnp.moveaxis(query, 1, 0)
83
+ key = jnp.moveaxis(key, 1, 0)
84
+ value = jnp.moveaxis(value, 1, 0)
85
+
86
+ if bias is not None:
87
+ for bias_dim, broadcast_dim in zip(bias.shape, (batch, num_heads, q_len, kv_len)):
88
+ assert bias_dim == 1 or bias_dim == broadcast_dim
89
+ if not deterministic and attn_pdrop > 0.0:
90
+ attn_dropout_rng, dropout_rng = jax.random.split(dropout_rng)
91
+ attn_dropout = jax.random.bernoulli(attn_dropout_rng, attn_pdrop, (batch, num_heads, q_len, kv_len))
92
+ else:
93
+ attn_dropout = None
94
+
95
+ _chunk_bias_fn = functools.partial(
96
+ _chunk_attention_bias,
97
+ query_chunk_size, key_chunk_size, bias, deterministic,
98
+ attn_dropout, attn_pdrop, causal, dtype)
99
+
100
+ def scan_attention(args):
101
+ query_chunk, query_chunk_idx = args
102
+
103
+ @functools.partial(jax.checkpoint, prevent_cse=prevent_cse, policy=policy)
104
+ def scan_kv_block(carry, args):
105
+ key_chunk, value_chunk, key_chunk_idx = args
106
+ (numerator, denominator, prev_max_score) = carry
107
+ attn_weights = jnp.einsum('bqhd,bkhd->bqhk', query_chunk, key_chunk, precision=precision)
108
+ bias_chunk = _chunk_bias_fn(query_chunk_idx, key_chunk_idx)
109
+ bias_chunk = jnp.moveaxis(bias_chunk, 1, 2)
110
+ attn_weights = attn_weights + bias_chunk
111
+
112
+ max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
113
+ max_score = jnp.maximum(prev_max_score, max_score)
114
+ max_score = jax.lax.stop_gradient(max_score)
115
+ exp_weights = jnp.exp(attn_weights - max_score)
116
+ exp_values = jnp.einsum(
117
+ 'bqhv,bvhd->bqhd', exp_weights, value_chunk, precision=precision
118
+ )
119
+ correction = jnp.exp(prev_max_score - max_score)
120
+ numerator = numerator * correction + exp_values
121
+ denominator = denominator * correction + exp_weights.sum(axis=-1, keepdims=True)
122
+ return Carry(numerator, denominator, max_score), None
123
+
124
+ def skip_upper_half(carry, args):
125
+ key_chunk, value_chunk, key_chunk_idx = args
126
+ skip_block = jnp.array(False)
127
+ if causal:
128
+ skip_block = query_chunk_idx < key_chunk_idx
129
+ return jax.lax.cond(
130
+ skip_block,
131
+ lambda carry, args: (carry, None),
132
+ scan_kv_block,
133
+ carry,
134
+ args,
135
+ )
136
+
137
+ init_carry = Carry(
138
+ jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype),
139
+ jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype),
140
+ (-jnp.inf) * jnp.ones((batch, query_chunk_size, num_heads, 1), dtype=query.dtype),
141
+ )
142
+ (numerator, denominator, max_score), _ = lax.scan(
143
+ skip_upper_half, init_carry, xs=(key, value, jnp.arange(0, num_kv))
144
+ )
145
+ outputs = (numerator / denominator).astype(dtype)
146
+ return outputs
147
+
148
+ _, res = lax.scan(
149
+ lambda _, x: ((), scan_attention(x)),
150
+ (), xs=(query, jnp.arange(0, num_q))
151
+ )
152
+ res = rearrange(res, 'n b c h d -> b (n c) h d')
153
+ return res
154
+
155
+
156
+ class Carry(NamedTuple):
157
+ numerator: jax.Array
158
+ denominator: jax.Array
159
+ max_so_far: jax.Array
160
+
161
+
162
+ def _chunk_attention_bias(query_chunk_size, key_chunk_size,
163
+ bias, deterministic, attn_dropout, attn_pdrop, causal,
164
+ dtype, query_chunk_idx, key_chunk_idx):
165
+ query_offset = query_chunk_idx * query_chunk_size
166
+ key_offset = key_chunk_idx * key_chunk_size
167
+ chunk_bias = jnp.zeros((1, 1, 1, 1), dtype=dtype)
168
+ if bias is not None:
169
+ chunk_bias = lax.dynamic_slice(
170
+ bias,
171
+ start_indices=(0, 0, query_offset, key_offset),
172
+ slice_sizes=(*bias.shape[:2], min(bias.shape[-2], query_chunk_size), min(bias.shape[-1], key_chunk_size)),
173
+ )
174
+
175
+ if causal:
176
+ query_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(query_chunk_size, 1), dimension=0)
177
+ key_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(1, key_chunk_size), dimension=1)
178
+ offset = query_offset - key_offset
179
+ query_idx += offset
180
+ causal_mask_value = (query_idx < key_idx) * jnp.finfo(dtype).min
181
+ chunk_bias += causal_mask_value.reshape(1, 1, *causal_mask_value.shape)
182
+
183
+ if not deterministic and attn_pdrop > 0.0:
184
+ attn_dropout_slice = lax.dynamic_slice(
185
+ attn_dropout,
186
+ start_indices=(0, 0, query_offset, key_offset),
187
+ slice_sizes=(
188
+ *attn_dropout.shape[:2],
189
+ min(attn_dropout.shape[-2], query_chunk_size),
190
+ min(attn_dropout.shape[-1], key_chunk_size),
191
+ ),
192
+ )
193
+ chunk_bias += attn_dropout_slice * jnp.finfo(dtype).min
194
+ return chunk_bias.astype(dtype)
195
+
196
+
197
+ if __name__ == '__main__':
198
+ # test
199
+ def reference_attn(query, key, value, causal, dtype):
200
+ query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
201
+ logits = jnp.einsum("bqhc,bkhc->bhqk", query, key)
202
+ if causal:
203
+ mask_value = jnp.finfo(logits.dtype).min
204
+ _, q_seq_len, _, _ = query.shape
205
+ _, kv_seq_len, _, _ = key.shape
206
+ mask_shape = (q_seq_len, kv_seq_len)
207
+ row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
208
+ col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
209
+ causal_mask = (row_ids < col_ids)[None, None, :, :]
210
+ logits = logits + jnp.where(causal_mask, mask_value, 0.0)
211
+ weights = jax.nn.softmax(logits, axis=-1)
212
+ out = jnp.einsum("bhqk,bkhc->bqhc", weights, value)
213
+ return out
214
+
215
+ # random inputs
216
+ shape = (1, 32, 8, 64)
217
+ query = jax.random.normal(jax.random.PRNGKey(0), shape)
218
+ key = jax.random.normal(jax.random.PRNGKey(1), shape)
219
+ value = jax.random.normal(jax.random.PRNGKey(2), shape)
220
+
221
+ causal = True
222
+ chunk_size = 4
223
+ policy = jax.checkpoint_policies.nothing_saveable()
224
+
225
+ blockwise = blockwise_attn(query, key, value, None, False, None, 0.0, causal, chunk_size, chunk_size, jnp.float32, policy, 'float32', True, False)
226
+ reference = reference_attn(query, key, value, causal, 'float32')
227
+
228
+ assert jnp.allclose(reference, blockwise, atol=1e-6)
EasyLM/checkpoint.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ from ml_collections import ConfigDict
4
+ import mlxu
5
+ import jax
6
+ import jax.numpy as jnp
7
+ import flax
8
+ from flax.serialization import (
9
+ from_bytes, to_bytes, to_state_dict, from_state_dict
10
+ )
11
+ from flax.traverse_util import flatten_dict, unflatten_dict, empty_node
12
+ import msgpack
13
+
14
+ from EasyLM.jax_utils import tree_apply, float_tensor_to_dtype
15
+
16
+
17
+ class StreamingCheckpointer(object):
18
+ """ Custom msgpack checkpointer that saves large train states by serializing
19
+ and saving tensors one by one in a streaming fashion. Avoids running
20
+ out of memory or local TPU disk with default flax checkpointer.
21
+ """
22
+
23
+ @staticmethod
24
+ def get_default_config(updates=None):
25
+ config = ConfigDict()
26
+ config.float_dtype = 'bf16'
27
+ config.save_optimizer_state = False
28
+
29
+ if updates is not None:
30
+ config.update(ConfigDict(updates).copy_and_resolve_references())
31
+ return config
32
+
33
+ def __init__(self, config, checkpoint_dir, enable=True):
34
+ self.config = self.get_default_config(config)
35
+ self.checkpoint_dir = checkpoint_dir
36
+ self.enable = enable
37
+
38
+ def save_checkpoint(self, train_state, filename, gather_fns=None):
39
+ if self.enable:
40
+ path = os.path.join(self.checkpoint_dir, filename)
41
+ else:
42
+ path = '/dev/null'
43
+ self.save_train_state_to_file(
44
+ train_state, path, gather_fns, self.config.float_dtype
45
+ )
46
+
47
+ @staticmethod
48
+ def save_train_state_to_file(train_state, path, gather_fns=None, float_dtype=None):
49
+ train_state = to_state_dict(train_state)
50
+ packer = msgpack.Packer()
51
+ flattend_train_state = flatten_dict(train_state)
52
+ if gather_fns is not None:
53
+ gather_fns = flatten_dict(to_state_dict(gather_fns))
54
+
55
+ with mlxu.open_file(path, "wb") as fout:
56
+ for key, value in flattend_train_state.items():
57
+ if gather_fns is not None:
58
+ value = gather_fns[key](value)
59
+ value = float_tensor_to_dtype(value, float_dtype)
60
+ fout.write(packer.pack((key, to_bytes(value))))
61
+
62
+ def save_pickle(self, obj, filename):
63
+ if self.enable:
64
+ path = os.path.join(self.checkpoint_dir, filename)
65
+ else:
66
+ path = '/dev/null'
67
+ mlxu.save_pickle(obj, path)
68
+
69
+ def save_all(self, train_state, gather_fns, metadata=None, dataset=None, milestone=False):
70
+ step = int(jax.device_get(train_state.step))
71
+ if self.config.save_optimizer_state:
72
+ checkpoint_state = train_state
73
+ checkpoint_name = 'streaming_train_state'
74
+ checkpoint_gather_fns = gather_fns
75
+ else:
76
+ checkpoint_state = train_state.params['params']
77
+ checkpoint_name = 'streaming_params'
78
+ checkpoint_gather_fns = gather_fns.params['params']
79
+
80
+ if milestone:
81
+ # Save a milestone checkpoint that will not be overwritten
82
+ self.save_pickle(metadata, f'metadata_{step}.pkl')
83
+ self.save_pickle(dataset, f'dataset_{step}.pkl')
84
+ self.save_checkpoint(
85
+ checkpoint_state, f'{checkpoint_name}_{step}', checkpoint_gather_fns
86
+ )
87
+ else:
88
+ # Save a normal checkpoint that can be overwritten
89
+ self.save_pickle(metadata, 'metadata.pkl')
90
+ self.save_pickle(dataset, 'dataset.pkl')
91
+ self.save_checkpoint(
92
+ checkpoint_state, f'{checkpoint_name}', checkpoint_gather_fns
93
+ )
94
+
95
+ @staticmethod
96
+ def load_checkpoint(path, target=None, shard_fns=None, remove_dict_prefix=None):
97
+ if shard_fns is not None:
98
+ shard_fns = flatten_dict(
99
+ to_state_dict(shard_fns)
100
+ )
101
+ if remove_dict_prefix is not None:
102
+ remove_dict_prefix = tuple(remove_dict_prefix)
103
+ flattend_train_state = {}
104
+ with mlxu.open_file(path) as fin:
105
+ # 83886080 bytes = 80 MB, which is 16 blocks on GCS
106
+ unpacker = msgpack.Unpacker(fin, read_size=83886080, max_buffer_size=0)
107
+ for key, value in unpacker:
108
+ key = tuple(key)
109
+ if remove_dict_prefix is not None:
110
+ if key[:len(remove_dict_prefix)] == remove_dict_prefix:
111
+ key = key[len(remove_dict_prefix):]
112
+ else:
113
+ continue
114
+
115
+ tensor = from_bytes(None, value)
116
+ if shard_fns is not None:
117
+ tensor = shard_fns[key](tensor)
118
+ flattend_train_state[key] = tensor
119
+
120
+ if target is not None:
121
+ flattened_target = flatten_dict(
122
+ to_state_dict(target), keep_empty_nodes=True
123
+ )
124
+ for key, value in flattened_target.items():
125
+ if key not in flattend_train_state and value == empty_node:
126
+ flattend_train_state[key] = value
127
+
128
+ train_state = unflatten_dict(flattend_train_state)
129
+ if target is None:
130
+ return train_state
131
+
132
+ return from_state_dict(target, train_state)
133
+
134
+ @staticmethod
135
+ def load_flax_checkpoint(path, target=None, shard_fns=None):
136
+ """ Load a standard flax checkpoint that's not saved with the
137
+ msgpack streaming format.
138
+ """
139
+ with mlxu.open_file(path, "rb") as fin:
140
+ encoded_bytes = fin.read()
141
+
142
+ state_dict = flax.serialization.msgpack_restore(encoded_bytes)
143
+ if shard_fns is not None:
144
+ shard_fns = to_state_dict(shard_fns)
145
+ state_dict = tree_apply(shard_fns, state_dict)
146
+
147
+ if target is None:
148
+ return state_dict
149
+ return from_state_dict(target, state_dict)
150
+
151
+ @classmethod
152
+ def load_trainstate_checkpoint(cls, load_from, trainstate_target=None,
153
+ trainstate_shard_fns=None,
154
+ disallow_trainstate=False):
155
+ if trainstate_target is not None:
156
+ params_target = trainstate_target.params['params']
157
+ else:
158
+ params_target = None
159
+
160
+ if trainstate_shard_fns is not None:
161
+ params_shard_fns = trainstate_shard_fns.params['params']
162
+ else:
163
+ params_shard_fns = None
164
+
165
+ load_type, load_path = load_from.split('::', 1)
166
+ if disallow_trainstate:
167
+ assert load_type != 'trainstate', 'Loading full trainstate is not allowed!'
168
+ train_state = None
169
+ restored_params = None
170
+ if load_type == 'trainstate':
171
+ # Load the entire train state in the streaming format
172
+ train_state = cls.load_checkpoint(
173
+ path=load_path,
174
+ target=trainstate_target,
175
+ shard_fns=trainstate_shard_fns,
176
+ )
177
+ elif load_type == 'trainstate_params':
178
+ # Load the params part of the train state in the streaming format
179
+ restored_params = cls.load_checkpoint(
180
+ path=load_path,
181
+ target=params_target,
182
+ shard_fns=params_shard_fns,
183
+ remove_dict_prefix=('params', 'params'),
184
+ )
185
+ restored_params = flax.core.frozen_dict.freeze(
186
+ {'params': restored_params}
187
+ )
188
+ elif load_type == 'params':
189
+ # Load the params in the streaming format
190
+ restored_params = cls.load_checkpoint(
191
+ path=load_path,
192
+ target=params_target,
193
+ shard_fns=params_shard_fns,
194
+ )
195
+ restored_params = flax.core.frozen_dict.freeze(
196
+ {'params': restored_params}
197
+ )
198
+ elif load_type == 'flax_params':
199
+ # Load the params in the standard flax format (non-streaming)
200
+ # This requires the entire params to fit in memory
201
+ restored_params = cls.load_flax_checkpoint(
202
+ path=load_path,
203
+ target=params_target,
204
+ shard_fns=params_shard_fns
205
+ )
206
+ restored_params = flax.core.frozen_dict.freeze(
207
+ {'params': restored_params}
208
+ )
209
+ else:
210
+ raise ValueError(f'Invalid load_from type: {load_type}')
211
+
212
+ return train_state, restored_params
EasyLM/data.py ADDED
@@ -0,0 +1,431 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import pprint
3
+ import time
4
+ from functools import partial
5
+ import json
6
+ import base64
7
+ from multiprocessing import Pool
8
+
9
+ import h5py
10
+ import mlxu
11
+ from ml_collections.config_dict import config_dict
12
+ from ml_collections import ConfigDict
13
+ from tqdm import tqdm, trange
14
+ import numpy as np
15
+
16
+ from datasets import load_dataset, load_from_disk
17
+
18
+
19
+ class DatasetFactory(object):
20
+ """ Datset builder class. """
21
+
22
+ @staticmethod
23
+ def get_default_config(updates=None):
24
+ config = ConfigDict()
25
+ config.type = 'huggingface'
26
+ config.text_processor = TextProcessor.get_default_config()
27
+ config.huggingface_dataset = HuggingfaceDataset.get_default_config()
28
+ config.json_dataset = JsonDataset.get_default_config()
29
+
30
+ if updates is not None:
31
+ config.update(ConfigDict(updates).copy_and_resolve_references())
32
+ return config
33
+
34
+ @classmethod
35
+ def load_dataset(cls, config, tokenizer, **kwargs):
36
+ config = cls.get_default_config(config)
37
+ text_processor = TextProcessor(config.text_processor, tokenizer)
38
+ if config.type == 'huggingface':
39
+ return HuggingfaceDataset(
40
+ config.huggingface_dataset, tokenizer, text_processor, **kwargs
41
+ )
42
+ elif config.type == 'json':
43
+ return JsonDataset(config.json_dataset, tokenizer, text_processor, **kwargs)
44
+ else:
45
+ raise ValueError(f'Unknown dataset type: {config.type}')
46
+
47
+ def __init__(self):
48
+ raise ValueError('DatasetFactory is a static class and should not be instantiated.')
49
+
50
+
51
+ class TextProcessor(object):
52
+ """ Example processor that converts a dictionary of texts into tokens. """
53
+
54
+ @staticmethod
55
+ def get_default_config(updates=None):
56
+ config = ConfigDict()
57
+ config.fields_from_example = ''
58
+ config.fields = ''
59
+ config.subfield_separator = ' '
60
+ config.add_bos_token = True
61
+ config.add_eos_token = True
62
+ config.prepend_text = ''
63
+ config.base64_token_dtype = 'i4'
64
+ if updates is not None:
65
+ config.update(ConfigDict(updates).copy_and_resolve_references())
66
+ return config
67
+
68
+ def __init__(self, config, tokenizer):
69
+ self.config = self.get_default_config(config)
70
+ assert self.config.fields != '' or self.config.fields_from_example != '', (
71
+ 'Either fields or fields_from_example must be specified.'
72
+ )
73
+ self.tokenizer = tokenizer
74
+
75
+ def __call__(self, example, has_aux=False):
76
+ if has_aux:
77
+ example, *aux = example
78
+ else:
79
+ aux = tuple()
80
+ token_buffer = []
81
+ loss_mask_buffer = []
82
+
83
+ if self.config.add_bos_token:
84
+ token_buffer.append(self.tokenizer.bos_token_id)
85
+ loss_mask_buffer.append(0.0)
86
+
87
+ if self.config.fields_from_example != '':
88
+ fields = example[self.config.fields_from_example].split(',')
89
+ else:
90
+ fields = self.config.fields.split(',')
91
+
92
+ for i, field in enumerate(fields):
93
+ if field.startswith('[') and field.endswith(']'):
94
+ # No loss for this field.
95
+ field = field[1:-1]
96
+ mask = 0.0
97
+ else:
98
+ mask = 1.0
99
+
100
+ if field.startswith('<|') and field.endswith('|>'):
101
+ # Special tokens.
102
+ field = field[2:-2]
103
+ if field == 'bos':
104
+ token_buffer.append(self.tokenizer.bos_token_id)
105
+ elif field == 'eos':
106
+ token_buffer.append(self.tokenizer.eos_token_id)
107
+ else:
108
+ # Token ID specified directly.
109
+ token_buffer.append(int(field))
110
+ loss_mask_buffer.append(mask)
111
+ elif field.startswith('{') and field.endswith('}'):
112
+ field = field[1:-1]
113
+ # Base64 encoded raw tokens.
114
+ tokens = np.frombuffer(
115
+ base64.b64decode(example[field]),
116
+ dtype=self.config.base64_token_dtype
117
+ ).tolist()
118
+ token_buffer.extend(tokens)
119
+ loss_mask_buffer.extend([mask for _ in range(len(tokens))])
120
+ else:
121
+ subfields = field.split('+')
122
+ text = self.config.subfield_separator.join(
123
+ [example[subfield] for subfield in subfields]
124
+ )
125
+ if i == 0:
126
+ text = self.config.prepend_text + text
127
+ tokens = self.tokenizer.encode(text)
128
+ token_buffer.extend(tokens)
129
+ loss_mask_buffer.extend([mask for _ in range(len(tokens))])
130
+
131
+ if self.config.add_eos_token:
132
+ token_buffer.append(self.tokenizer.eos_token_id)
133
+ loss_mask_buffer.append(1.0)
134
+
135
+ return token_buffer, loss_mask_buffer, *aux
136
+
137
+
138
+ class HuggingfaceDataset(object):
139
+ """ Huggingface dataset, where the dataset is loaded using the huggingface
140
+ datasets.load_dataset() function.
141
+ """
142
+
143
+ @staticmethod
144
+ def get_default_config(updates=None):
145
+ config = ConfigDict()
146
+ config.path = 'c4'
147
+ config.name = 'en'
148
+ config.split = 'train'
149
+ config.streaming = False
150
+ config.seq_length = 1024
151
+ config.batch_size = 8
152
+ config.always_start_with_bos = False
153
+ config.start_seek_loc = 0
154
+ config.tokens_count_at_start = 0
155
+ config.batch_token_dtype = 'i4'
156
+
157
+ if updates is not None:
158
+ config.update(ConfigDict(updates).copy_and_resolve_references())
159
+ return config
160
+
161
+ def __init__(self, config, tokenizer, text_processor, eval_dataset=False):
162
+ self.config = self.get_default_config(config)
163
+ name = self.config.name if self.config.name != '' else None
164
+ split = self.config.split if self.config.split != '' else None
165
+ self._tokenizer = tokenizer
166
+ self._text_processor = text_processor
167
+ self._dataset = load_from_disk(
168
+ self.config.path
169
+ )[split]
170
+ self._dataset = self._dataset.to_iterable_dataset(num_shards=128 if len(self._dataset) > 128 else len(self._dataset))
171
+ self._eval_dataset = eval_dataset
172
+ self._train_epochs = 0
173
+ self._dataset_loc = self.config.start_seek_loc
174
+ self._total_tokens = self.config.tokens_count_at_start
175
+ self._index = 0
176
+
177
+ def __iter__(self):
178
+ chunk_size = self.config.batch_size * self.config.seq_length
179
+ total_tokens = 0
180
+ while True:
181
+ token_buffer = []
182
+ loss_mask_buffer = []
183
+ if not self._eval_dataset:
184
+ self._shuffle()
185
+ for index, example in enumerate(self._dataset):
186
+ self._index = index
187
+ if not self._eval_dataset and self._dataset_loc > index:
188
+ continue
189
+ tokens, loss_masks = self.text_processor(example)
190
+ token_buffer.extend(tokens)
191
+ loss_mask_buffer.extend(loss_masks)
192
+ while len(token_buffer) > chunk_size + 1:
193
+ self._total_tokens += chunk_size
194
+ metrics = {
195
+ 'dataset_example_index': index,
196
+ 'dataset_total_tokens': self._total_tokens,
197
+ 'epoch': self._train_epochs,
198
+ }
199
+ batch = {
200
+ 'input_tokens': np.array(token_buffer[:chunk_size], dtype=self.config.batch_token_dtype).reshape(
201
+ self.config.batch_size, -1
202
+ ),
203
+ 'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=self.config.batch_token_dtype).reshape(
204
+ self.config.batch_size, -1
205
+ ),
206
+ 'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(
207
+ self.config.batch_size, -1
208
+ ),
209
+ }
210
+ if self.config.always_start_with_bos:
211
+ batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id
212
+ yield batch, metrics
213
+ token_buffer = token_buffer[chunk_size:]
214
+ loss_mask_buffer = loss_mask_buffer[chunk_size:]
215
+
216
+ if self._eval_dataset:
217
+ break
218
+ else:
219
+ self._dataset_loc = 0
220
+ self._shuffle()
221
+ self._train_epochs += 1
222
+ print(f"TRAIN {self._train_epochs} EPOCH DONE")
223
+
224
+ def _shuffle(self):
225
+ self._dataset = self._dataset.shuffle(buffer_size=100)
226
+
227
+ def get_state_dict(self):
228
+ return dict(
229
+ config=self.config,
230
+ dataset_loc=self._index,
231
+ total_tokens=self._total_tokens,
232
+ epochs=self._train_epochs,
233
+ )
234
+
235
+ def load_state_dict(self, state_dict):
236
+ if 'config' in state_dict:
237
+ self.config.update(ConfigDict(state_dict['config']))
238
+ self._dataset_loc = state_dict.get('dataset_loc', self.config.start_seek_loc)
239
+ self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)
240
+ self._train_epochs = state_dict.get('epochs', 0)
241
+
242
+ @property
243
+ def seq_length(self):
244
+ return self.config.seq_length
245
+
246
+ @property
247
+ def tokenizer(self):
248
+ return self._tokenizer
249
+
250
+ @property
251
+ def text_processor(self):
252
+ return self._text_processor
253
+
254
+ @property
255
+ def dataset(self):
256
+ return self._dataset
257
+
258
+ @property
259
+ def vocab_size(self):
260
+ return len(self._tokenizer)
261
+
262
+
263
+ class JsonDataset(object):
264
+ """ JSON dataset, where each line of the data file contains a JSON
265
+ dictionary with text fields.
266
+ """
267
+
268
+ @staticmethod
269
+ def get_default_config(updates=None):
270
+ config = ConfigDict()
271
+ config.path = ''
272
+ config.seq_length = 1024
273
+ config.batch_size = 8
274
+ config.always_start_with_bos = False
275
+ config.start_seek_loc = 0
276
+ config.example_index_at_start = 0
277
+ config.tokens_count_at_start = 0
278
+ config.tokenizer_processes = 1
279
+ config.tokenizer_parallel_chunk_size = 32
280
+ config.tokenizer_parallel_batch_size = 1024
281
+ config.throughput_average_window_size = 200
282
+
283
+ if updates is not None:
284
+ config.update(ConfigDict(updates).copy_and_resolve_references())
285
+ return config
286
+
287
+ def __init__(self, config, tokenizer, text_processor):
288
+ self.config = self.get_default_config(config)
289
+ assert self.config.path != ''
290
+ self._tokenizer = tokenizer
291
+ self._text_processor = text_processor
292
+ self._index = self.config.example_index_at_start
293
+ self._file_loc = self.config.start_seek_loc
294
+ self._total_tokens = self.config.tokens_count_at_start
295
+
296
+ def parse_json(self, line):
297
+ if not line or line == '\n':
298
+ return None
299
+ try:
300
+ data = json.loads(line)
301
+ except json.decoder.JSONDecodeError:
302
+ print(f'Error parsing json line:\n{line}')
303
+ return None
304
+ return data
305
+
306
+ def json_iterator(self):
307
+ with mlxu.open_file(self.config.path, 'r') as fin:
308
+ fin.seek(self._file_loc)
309
+ while True:
310
+ line = fin.readline()
311
+ self._file_loc = fin.tell()
312
+ if not line: # Reached EOF
313
+ self._index = 0
314
+ fin.seek(0)
315
+ continue
316
+
317
+ data = self.parse_json(line)
318
+ if data is not None:
319
+ # JSON parsing succeeded
320
+ yield data, self._file_loc, self._index
321
+ self._index += 1
322
+
323
+ def batched(self, iterator, batch_size):
324
+ batch = []
325
+ for example in iterator:
326
+ batch.append(example)
327
+ if len(batch) == batch_size:
328
+ yield batch
329
+ batch = []
330
+ if len(batch) > 0:
331
+ yield batch
332
+
333
+ def parallel_example_iterator(self):
334
+ if self.config.tokenizer_processes == 1:
335
+ for example, loc, index in self.json_iterator():
336
+ yield self.text_processor((example, loc, index), has_aux=True)
337
+ else:
338
+ process_pool = Pool(self.config.tokenizer_processes)
339
+ batched_iterator = self.batched(
340
+ self.json_iterator(), self.config.tokenizer_parallel_batch_size
341
+ )
342
+ with process_pool as pool:
343
+ map_fn = partial(self.text_processor, has_aux=True)
344
+ next_batch = pool.map_async(
345
+ map_fn, next(batched_iterator),
346
+ chunksize=self.config.tokenizer_parallel_chunk_size
347
+ )
348
+ while True:
349
+ current_batch = next_batch
350
+ next_batch = pool.map_async(
351
+ map_fn, next(batched_iterator),
352
+ chunksize=self.config.tokenizer_parallel_chunk_size
353
+ )
354
+ for example in current_batch.get():
355
+ yield example
356
+
357
+ def __iter__(self):
358
+ chunk_size = self.config.batch_size * self.config.seq_length
359
+ token_buffer = []
360
+ loss_mask_buffer = []
361
+ last_time = 0.0
362
+ step_times = []
363
+ start_time = time.time()
364
+ start_tokens = self._total_tokens
365
+ for tokens, loss_masks, loc, index in self.parallel_example_iterator():
366
+ token_buffer.extend(tokens)
367
+ loss_mask_buffer.extend(loss_masks)
368
+ while len(token_buffer) > chunk_size + 1:
369
+ self._total_tokens += chunk_size
370
+ step_times.append(time.time() - last_time)
371
+ last_time = time.time()
372
+ if len(step_times) > self.config.throughput_average_window_size:
373
+ step_times = step_times[-self.config.throughput_average_window_size:]
374
+ average_throughput = chunk_size / np.mean(step_times)
375
+ accumulated_throughput = (
376
+ (self._total_tokens - start_tokens) / (time.time() - start_time)
377
+ )
378
+ metrics = {
379
+ 'dataset_file_loc': loc,
380
+ 'dataset_example_index': index,
381
+ 'dataset_total_tokens': self._total_tokens,
382
+ 'dataset_accumulated_tps': accumulated_throughput,
383
+ 'dataset_average_tps': average_throughput,
384
+ }
385
+ batch = {
386
+ 'input_tokens': np.array(token_buffer[:chunk_size], dtype=np.int32).reshape(
387
+ self.config.batch_size, -1
388
+ ),
389
+ 'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=np.int32).reshape(
390
+ self.config.batch_size, -1
391
+ ),
392
+ 'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(
393
+ self.config.batch_size, -1
394
+ ),
395
+ }
396
+ if self.config.always_start_with_bos:
397
+ batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id
398
+ yield batch, metrics
399
+ token_buffer = token_buffer[chunk_size:]
400
+ loss_mask_buffer = loss_mask_buffer[chunk_size:]
401
+
402
+ def get_state_dict(self):
403
+ return dict(
404
+ config=self.config,
405
+ index=self._index,
406
+ file_loc=self._file_loc,
407
+ total_tokens=self._total_tokens,
408
+ )
409
+
410
+ def load_state_dict(self, state_dict):
411
+ if 'config' in state_dict:
412
+ self.config.update(ConfigDict(state_dict['config']))
413
+ self._index = state_dict.get('index', self.config.example_index_at_start)
414
+ self._file_loc = state_dict.get('file_loc', self.config.start_seek_loc)
415
+ self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)
416
+
417
+ @property
418
+ def seq_length(self):
419
+ return self.config.seq_length
420
+
421
+ @property
422
+ def tokenizer(self):
423
+ return self._tokenizer
424
+
425
+ @property
426
+ def text_processor(self):
427
+ return self._text_processor
428
+
429
+ @property
430
+ def vocab_size(self):
431
+ return len(self.tokenizer)
EasyLM/jax_utils.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ from typing import Any, Mapping, Text, Tuple, Union, NamedTuple
4
+ from functools import partial
5
+ import re
6
+ import dataclasses
7
+ import random
8
+ from ml_collections import ConfigDict
9
+ from ml_collections.config_dict.config_dict import placeholder
10
+
11
+ import flax
12
+ import jax
13
+ import jax.numpy as jnp
14
+ from jax.sharding import PartitionSpec as PS
15
+ from jax.sharding import Mesh
16
+ from jax.experimental import mesh_utils
17
+ from jax.experimental.pjit import with_sharding_constraint as _with_sharding_constraint
18
+ from jax.experimental.pjit import pjit
19
+ from jax.interpreters import pxla
20
+ import numpy as np
21
+ from transformers import FlaxLogitsWarper
22
+
23
+
24
+ class JaxRNG(object):
25
+ """ A convenient stateful Jax RNG wrapper. Can be used to wrap RNG inside
26
+ pure function.
27
+ """
28
+
29
+ @classmethod
30
+ def from_seed(cls, seed):
31
+ return cls(jax.random.PRNGKey(seed))
32
+
33
+ def __init__(self, rng):
34
+ self.rng = rng
35
+
36
+ def __call__(self, keys=None):
37
+ if keys is None:
38
+ self.rng, split_rng = jax.random.split(self.rng)
39
+ return split_rng
40
+ elif isinstance(keys, int):
41
+ split_rngs = jax.random.split(self.rng, num=keys + 1)
42
+ self.rng = split_rngs[0]
43
+ return tuple(split_rngs[1:])
44
+ else:
45
+ split_rngs = jax.random.split(self.rng, num=len(keys) + 1)
46
+ self.rng = split_rngs[0]
47
+ return {key: val for key, val in zip(keys, split_rngs[1:])}
48
+
49
+
50
+ class JaxDistributedConfig(object):
51
+ """ Utility class for initializing JAX distributed. """
52
+
53
+ @staticmethod
54
+ def get_default_config(updates=None):
55
+ config = ConfigDict()
56
+ config.initialize_jax_distributed = False
57
+ config.coordinator_address = placeholder(str)
58
+ config.num_processes = placeholder(int)
59
+ config.process_id = placeholder(int)
60
+ config.local_device_ids = placeholder(str)
61
+
62
+ if updates is not None:
63
+ config.update(ConfigDict(updates).copy_and_resolve_references())
64
+ return config
65
+
66
+ @classmethod
67
+ def initialize(cls, config):
68
+ config = cls.get_default_config(config)
69
+ if config.initialize_jax_distributed:
70
+ if config.local_device_ids is not None:
71
+ local_device_ids = [int(x) for x in config.local_device_ids.split(',')]
72
+ else:
73
+ local_device_ids = None
74
+
75
+ jax.distributed.initialize(
76
+ coordinator_address=config.coordinator_address,
77
+ num_processes=config.num_processes,
78
+ process_id=config.process_id,
79
+ local_device_ids=local_device_ids,
80
+ )
81
+
82
+
83
+ class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):
84
+ """ JIT traceable version of FlaxLogitsWarper that performs temperature scaling."""
85
+ def __init__(self, temperature):
86
+ self.temperature = temperature
87
+
88
+ def __call__(self, input_ids, scores, cur_len):
89
+ return scores / jnp.clip(self.temperature, a_min=1e-8)
90
+
91
+
92
+ def make_shard_and_gather_fns(partition_specs, dtype_specs=None):
93
+ """ Create pytree of sharding and gathering functions from pytree of
94
+ partition specs.
95
+ """
96
+ float_dtypes = (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64)
97
+
98
+ def make_to_dtype_fn(dtype_spec):
99
+ def to_dtype(tensor):
100
+ if dtype_specs in float_dtypes and getattr(tensor, 'dtype', None) in float_dtypes:
101
+ # Convert all float tensors to the same dtype
102
+ return tensor.astype(dtype_specs)
103
+ elif hasattr(dtype_spec, 'dtype') and hasattr(tensor, 'dtype'):
104
+ return tensor.astype(dtype_spec.dtype)
105
+ return tensor
106
+ return to_dtype
107
+
108
+ def make_shard_fn(partition_spec, dtype_spec=None):
109
+ jax_shard_function = pjit(
110
+ make_to_dtype_fn(dtype_spec),
111
+ in_shardings=None,
112
+ out_shardings=partition_spec
113
+ )
114
+ def shard_fn(tensor):
115
+ return jax_shard_function(tensor).block_until_ready()
116
+ return shard_fn
117
+
118
+ def make_gather_fn(partition_spec, dtype_spec=None):
119
+ jax_gather_fn = pjit(
120
+ make_to_dtype_fn(dtype_spec),
121
+ in_shardings=partition_spec,
122
+ out_shardings=None
123
+ )
124
+ def gather_fn(tensor):
125
+ return jax.device_get(jax_gather_fn(tensor))
126
+ return gather_fn
127
+
128
+ if dtype_specs is None or dtype_specs in float_dtypes:
129
+ shard_fns = jax.tree_util.tree_map(make_shard_fn, partition_specs)
130
+ gather_fns = jax.tree_util.tree_map(make_gather_fn, partition_specs)
131
+ else:
132
+ shard_fns = jax.tree_util.tree_map(
133
+ make_shard_fn, partition_specs, dtype_specs
134
+ )
135
+ gather_fns = jax.tree_util.tree_map(
136
+ make_gather_fn, partition_specs, dtype_specs
137
+ )
138
+ return shard_fns, gather_fns
139
+
140
+
141
+ def set_random_seed(seed):
142
+ np.random.seed(seed)
143
+ random.seed(seed)
144
+ init_rng(seed)
145
+
146
+
147
+ def get_jax_mesh(axis_dims, names):
148
+ if axis_dims.startswith('!'):
149
+ # Allow splitting a physical mesh axis if needed
150
+ mesh_axis_splitting = True
151
+ axis_dims = axis_dims[1:]
152
+ else:
153
+ mesh_axis_splitting = False
154
+
155
+ if ':' in axis_dims:
156
+ dims = []
157
+ dim_names = []
158
+ for axis in axis_dims.split(','):
159
+ name, dim = axis.split(':')
160
+ assert name in names
161
+ dims.append(int(dim))
162
+ dim_names.append(name)
163
+ assert(set(dim_names) == set(names))
164
+ else:
165
+ dims = [int(x) for x in axis_dims.split(',')]
166
+ dim_names = names
167
+ assert len(dims) == len(names)
168
+ mesh_shape = np.arange(jax.device_count()).reshape(dims).shape
169
+ if mesh_axis_splitting:
170
+ physical_mesh = np.array(jax.devices()).reshape(mesh_shape)
171
+ else:
172
+ physical_mesh = mesh_utils.create_device_mesh(mesh_shape)
173
+ return Mesh(physical_mesh, dim_names)
174
+
175
+
176
+ def names_in_current_mesh(*names):
177
+ """ Check if current mesh axes contain these names. """
178
+ mesh_axis_names = pxla.thread_resources.env.physical_mesh.axis_names
179
+ return set(names) <= set(mesh_axis_names)
180
+
181
+
182
+ def get_names_from_parition_spec(partition_specs):
183
+ """ Return axis names from partition specs. """
184
+ names = set()
185
+ if isinstance(partition_specs, dict):
186
+ partition_specs = partition_specs.values()
187
+ for item in partition_specs:
188
+ if item is None:
189
+ continue
190
+ elif isinstance(item, str):
191
+ names.add(item)
192
+ else:
193
+ names.update(get_names_from_parition_spec(item))
194
+
195
+ return list(names)
196
+
197
+
198
+ def with_sharding_constraint(x, partition_specs):
199
+ """ A smarter version of with_sharding_constraint that only applies the
200
+ constraint if the current mesh contains the axes in the partition specs.
201
+ """
202
+ axis_names = get_names_from_parition_spec(partition_specs)
203
+ if names_in_current_mesh(*axis_names):
204
+ x = _with_sharding_constraint(x, partition_specs)
205
+ return x
206
+
207
+
208
+ def wrap_function_with_rng(rng):
209
+ """ To be used as decorator, automatically bookkeep a RNG for the wrapped function. """
210
+ def wrap_function(function):
211
+ def wrapped(*args, **kwargs):
212
+ nonlocal rng
213
+ rng, split_rng = jax.random.split(rng)
214
+ return function(split_rng, *args, **kwargs)
215
+ return wrapped
216
+ return wrap_function
217
+
218
+
219
+ def init_rng(seed):
220
+ global jax_utils_rng
221
+ jax_utils_rng = JaxRNG.from_seed(seed)
222
+
223
+
224
+ def next_rng(*args, **kwargs):
225
+ global jax_utils_rng
226
+ return jax_utils_rng(*args, **kwargs)
227
+
228
+
229
+ def get_metrics(metrics, unreplicate=False, stack=False):
230
+ if unreplicate:
231
+ metrics = flax.jax_utils.unreplicate(metrics)
232
+ metrics = jax.device_get(metrics)
233
+ if stack:
234
+ return jax.tree_map(lambda *args: np.stack(args), *metrics)
235
+ else:
236
+ return {key: float(val) for key, val in metrics.items()}
237
+
238
+
239
+ def mse_loss(val, target, valid=None):
240
+ if valid is None:
241
+ valid = jnp.ones((*target.shape[:2], 1))
242
+ valid = valid.astype(jnp.float32)
243
+ loss = jnp.mean(
244
+ jnp.where(
245
+ valid > 0.0,
246
+ jnp.square(val - target),
247
+ 0.0
248
+ )
249
+ )
250
+ return loss
251
+
252
+
253
+ def cross_entropy_loss_and_accuracy(logits, tokens, valid=None):
254
+ if valid is None:
255
+ valid = jnp.ones(tokens.shape[:2])
256
+ valid = valid.astype(jnp.float32)
257
+ valid_text_length = jnp.maximum(jnp.sum(valid, axis=-1), 1e-10)
258
+ logits = logits.astype(jnp.float32) # for numerical stability
259
+ token_log_prob = jnp.squeeze(
260
+ jnp.take_along_axis(
261
+ jax.nn.log_softmax(logits, axis=-1),
262
+ jnp.expand_dims(tokens, -1),
263
+ axis=-1,
264
+ ),
265
+ -1,
266
+ )
267
+ token_log_prob = jnp.where(valid > 0.0, token_log_prob, jnp.array(0.0))
268
+ loss = -jnp.mean(jnp.sum(token_log_prob, axis=-1) / valid_text_length)
269
+ correct = jnp.where(
270
+ valid > 0.0,
271
+ jnp.argmax(logits, axis=-1) == tokens,
272
+ jnp.array(False)
273
+ )
274
+ accuracy = jnp.mean(jnp.sum(correct, axis=-1) / valid_text_length)
275
+ return loss, accuracy
276
+
277
+
278
+ def global_norm(tree):
279
+ """ Return the global L2 norm of a pytree. """
280
+ squared = jax.tree_util.tree_map(lambda x: jnp.sum(jnp.square(x)), tree)
281
+ flattened, _ = jax.flatten_util.ravel_pytree(squared)
282
+ return jnp.sqrt(jnp.sum(flattened))
283
+
284
+
285
+ def average_metrics(metrics):
286
+ with jax.spmd_mode("allow_all"):
287
+ return jax.tree_map(
288
+ lambda *args: jnp.mean(jnp.stack(args)),
289
+ *metrics
290
+ )
291
+
292
+
293
+ def get_float_dtype_by_name(dtype):
294
+ return {
295
+ 'bf16': jnp.bfloat16,
296
+ 'bfloat16': jnp.bfloat16,
297
+ 'fp16': jnp.float16,
298
+ 'float16': jnp.float16,
299
+ 'fp32': jnp.float32,
300
+ 'float32': jnp.float32,
301
+ 'fp64': jnp.float64,
302
+ 'float64': jnp.float64,
303
+ }[dtype]
304
+
305
+
306
+ def float_tensor_to_dtype(tensor, dtype):
307
+ if dtype is None or dtype == '':
308
+ return tensor
309
+ if isinstance(dtype, str):
310
+ dtype = get_float_dtype_by_name(dtype)
311
+ float_dtypes = (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64)
312
+ if getattr(tensor, 'dtype', None) in float_dtypes:
313
+ tensor = tensor.astype(dtype)
314
+ return tensor
315
+
316
+
317
+ def float_to_dtype(tree, dtype):
318
+ return jax.tree_util.tree_map(
319
+ partial(float_tensor_to_dtype, dtype=dtype), tree
320
+ )
321
+
322
+
323
+ def get_gradient_checkpoint_policy(name):
324
+ return {
325
+ 'everything_saveable': jax.checkpoint_policies.everything_saveable,
326
+ 'nothing_saveable': jax.checkpoint_policies.nothing_saveable,
327
+ 'checkpoint_dots': jax.checkpoint_policies.checkpoint_dots,
328
+ 'checkpoint_dots_with_no_batch_dims': jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims,
329
+ }[name]
330
+
331
+
332
+ def tree_path_to_string(path, sep=None):
333
+ keys = []
334
+ for key in path:
335
+ if isinstance(key, jax.tree_util.SequenceKey):
336
+ keys.append(str(key.idx))
337
+ elif isinstance(key, jax.tree_util.DictKey):
338
+ keys.append(str(key.key))
339
+ elif isinstance(key, jax.tree_util.GetAttrKey):
340
+ keys.append(str(key.name))
341
+ elif isinstance(key, jax.tree_util.FlattenedIndexKey):
342
+ keys.append(str(key.key))
343
+ else:
344
+ keys.append(str(key))
345
+ if sep is None:
346
+ return tuple(keys)
347
+ return sep.join(keys)
348
+
349
+
350
+ def flatten_tree(xs, is_leaf=None, sep=None):
351
+ flattened, _ = jax.tree_util.tree_flatten_with_path(xs, is_leaf=is_leaf)
352
+ output = {}
353
+ for key, val in flattened:
354
+ output[tree_path_to_string(key, sep=sep)] = val
355
+ return output
356
+
357
+
358
+ def named_tree_map(f, tree, *rest, is_leaf=None, sep=None):
359
+ """ An extended version of jax.tree_util.tree_map, where the mapped function
360
+ f takes both the name (path) and the tree leaf as input.
361
+ """
362
+ return jax.tree_util.tree_map_with_path(
363
+ lambda path, x, *r: f(tree_path_to_string(path, sep=sep), x, *r),
364
+ tree, *rest,
365
+ is_leaf=is_leaf
366
+ )
367
+
368
+
369
+ def match_partition_rules(rules, params):
370
+ """ Returns a pytree of PartitionSpec according to rules. Supports handling
371
+ Flax TrainState and Optax optimizer state.
372
+ """
373
+ def get_partition_spec(name, leaf):
374
+ if len(leaf.shape) == 0 or np.prod(leaf.shape) == 1:
375
+ """ Don't partition scalar values. """
376
+ return PS()
377
+ for rule, ps in rules:
378
+ if re.search(rule, name) is not None:
379
+ return ps
380
+ raise ValueError(f'Partition rule not found for param: {name}')
381
+ return named_tree_map(get_partition_spec, params, sep='/')
382
+
383
+
384
+ def get_weight_decay_mask(exclusions):
385
+ """ Return a weight decay mask function that computes the pytree masks
386
+ according to the given exclusion rules.
387
+ """
388
+ def decay(name, _):
389
+ for rule in exclusions:
390
+ if re.search(rule, name) is not None:
391
+ return False
392
+ return True
393
+
394
+ def weight_decay_mask(params):
395
+ return named_tree_map(decay, params, sep='/')
396
+
397
+ return weight_decay_mask
398
+
399
+
400
+ def tree_apply(fns, tree):
401
+ """ Apply a pytree of functions to the pytree. """
402
+ return jax.tree_util.tree_map(lambda fn, x: fn(x), fns, tree)
403
+
EasyLM/models/__init__.py ADDED
File without changes
EasyLM/models/gptj/__init__.py ADDED
File without changes
EasyLM/models/gptj/gptj_model.py ADDED
@@ -0,0 +1,1054 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The EleutherAI and The HuggingFace Inc. team.
3
+ # Modifications copyright 2022 Xinyang Geng
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+
18
+ from functools import partial
19
+ from typing import Optional, Tuple
20
+ import json
21
+
22
+ import numpy as np
23
+
24
+ import flax.linen as nn
25
+ import jax
26
+ import jax.numpy as jnp
27
+ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
28
+ from flax.linen import combine_masks, make_causal_mask
29
+ from flax.linen.attention import dot_product_attention_weights
30
+ from flax.traverse_util import flatten_dict, unflatten_dict
31
+ from jax import lax
32
+ from flax.linen import partitioning as nn_partitioning
33
+
34
+ from transformers.configuration_utils import PretrainedConfig
35
+ from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
36
+ from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
37
+ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
38
+ from transformers.generation.flax_logits_process import FlaxLogitsProcessorList
39
+ from transformers import AutoTokenizer
40
+ from jax.sharding import PartitionSpec
41
+
42
+ from ml_collections import ConfigDict
43
+ from ml_collections.config_dict import config_dict
44
+ from mlxu import function_args_to_config, load_pickle, open_file
45
+
46
+ from EasyLM.jax_utils import (
47
+ with_sharding_constraint, get_jax_mesh, get_gradient_checkpoint_policy
48
+ )
49
+
50
+
51
+ """
52
+ The follow code is taken from
53
+ transformers/src/transformers/models/gptj/configuration_gptj.py
54
+ and modified to work with EasyLM.
55
+ """
56
+
57
+
58
+ GPTJ_STANDARD_CONFIGS = {
59
+ '6b': {
60
+ "vocab_size": 50400,
61
+ "n_positions": 2048,
62
+ "n_embd": 4096,
63
+ "n_layer": 28,
64
+ "n_head": 16,
65
+ "rotary_dim": 64,
66
+ "n_inner": None,
67
+ "activation_function": "gelu_new",
68
+ "layer_norm_epsilon": 1e-5,
69
+ "initializer_range": 0.02,
70
+ "scale_attn_weights": True,
71
+ "use_cache": True,
72
+ "bos_token_id": 50256,
73
+ "eos_token_id": 50256,
74
+ "tie_word_embeddings": False,
75
+ "n_real_tokens": 50257,
76
+ }
77
+ }
78
+
79
+
80
+ class GPTJConfig(PretrainedConfig):
81
+ r"""
82
+ This is the configuration class to store the configuration of a [`GPTJModel`]. It is used to instantiate a GPT-J
83
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
84
+ defaults will yield a similar configuration to that of the GPT-J
85
+ [EleutherAI/gpt-j-6B](https://huggingface.co/EleutherAI/gpt-j-6B) architecture. Configuration objects inherit from
86
+ [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`]
87
+ for more information.
88
+ Args:
89
+ vocab_size (`int`, *optional*, defaults to 50400):
90
+ Vocabulary size of the GPT-J model. Defines the number of different tokens that can be represented by the
91
+ `inputs_ids` passed when calling [`GPTJModel`].
92
+ n_positions (`int`, *optional*, defaults to 2048):
93
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
94
+ just in case (e.g., 512 or 1024 or 2048).
95
+ n_embd (`int`, *optional*, defaults to 4096):
96
+ Dimensionality of the embeddings and hidden states.
97
+ n_layer (`int`, *optional*, defaults to 28):
98
+ Number of hidden layers in the Transformer encoder.
99
+ n_head (`int`, *optional*, defaults to 16):
100
+ Number of attention heads for each attention layer in the Transformer encoder.
101
+ rotary_dim (`int`, *optional*, defaults to 64):
102
+ Number of dimensions in the embedding that Rotary Position Embedding is applied to.
103
+ n_inner (`int`, *optional*, defaults to 0):
104
+ Dimensionality of the inner feed-forward layers. 0 will set it to 4 times n_embd
105
+ activation_function (`str`, *optional*, defaults to `"gelu_new"`):
106
+ Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
107
+ resid_pdrop (`float`, *optional*, defaults to 0.1):
108
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
109
+ embd_pdrop (`int`, *optional*, defaults to 0.1):
110
+ The dropout ratio for the embeddings.
111
+ attn_pdrop (`float`, *optional*, defaults to 0.1):
112
+ The dropout ratio for the attention.
113
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
114
+ The epsilon to use in the layer normalization layers.
115
+ initializer_range (`float`, *optional*, defaults to 0.02):
116
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
117
+ scale_attn_weights (`bool`, *optional*, defaults to `True`):
118
+ Scale attention weights by dividing by sqrt(hidden_size).
119
+ use_cache (`bool`, *optional*, defaults to `True`):
120
+ Whether or not the model should return the last key/values attentions (not used by all models).
121
+ Example:
122
+ ```python
123
+ >>> from transformers import GPTJModel, GPTJConfig
124
+ >>> # Initializing a GPT-J 6B configuration
125
+ >>> configuration = GPTJConfig()
126
+ >>> # Initializing a model from the configuration
127
+ >>> model = GPTJModel(configuration)
128
+ >>> # Accessing the model configuration
129
+ >>> configuration = model.config
130
+ ```"""
131
+ model_type = "gptj"
132
+ attribute_map = {
133
+ "max_position_embeddings": "n_positions",
134
+ "hidden_size": "n_embd",
135
+ "num_attention_heads": "n_head",
136
+ "num_hidden_layers": "n_layer",
137
+ }
138
+
139
+ def __init__(
140
+ self,
141
+ vocab_size=50400,
142
+ n_positions=2048,
143
+ n_embd=4096,
144
+ n_layer=28,
145
+ n_head=16,
146
+ rotary_dim=64,
147
+ n_inner=None,
148
+ activation_function="gelu_new",
149
+ resid_pdrop=0.0,
150
+ embd_pdrop=0.0,
151
+ attn_pdrop=0.0,
152
+ layer_norm_epsilon=1e-5,
153
+ initializer_range=0.02,
154
+ scale_attn_weights=True,
155
+ use_cache=True,
156
+ bos_token_id=50256,
157
+ eos_token_id=50256,
158
+ tie_word_embeddings=False,
159
+ gradient_checkpointing=True,
160
+ gradient_checkpointing_policy='nothing_saveable',
161
+ n_real_tokens=50257,
162
+ fcm_min_ratio=0.0,
163
+ fcm_max_ratio=0.0,
164
+ **kwargs
165
+ ):
166
+ self.vocab_size = vocab_size
167
+ self.n_positions = n_positions
168
+ self.n_embd = n_embd
169
+ self.n_layer = n_layer
170
+ self.n_head = n_head
171
+ self.n_inner = n_inner
172
+ self.rotary_dim = rotary_dim
173
+ self.activation_function = activation_function
174
+ self.resid_pdrop = resid_pdrop
175
+ self.embd_pdrop = embd_pdrop
176
+ self.attn_pdrop = attn_pdrop
177
+ self.layer_norm_epsilon = layer_norm_epsilon
178
+ self.initializer_range = initializer_range
179
+ self.scale_attn_weights = scale_attn_weights
180
+ self.use_cache = use_cache
181
+ self.gradient_checkpointing = gradient_checkpointing
182
+ self.gradient_checkpointing_policy = gradient_checkpointing_policy
183
+ self.n_real_tokens = n_real_tokens
184
+ self.fcm_min_ratio = fcm_min_ratio
185
+ self.fcm_max_ratio = fcm_max_ratio
186
+ if self.n_real_tokens is None:
187
+ self.n_real_tokens = self.vocab_size
188
+
189
+ self.bos_token_id = bos_token_id
190
+ self.eos_token_id = eos_token_id
191
+
192
+ super().__init__(
193
+ bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
194
+ )
195
+
196
+ @classmethod
197
+ def get_default_config(cls, updates=None):
198
+ none_arg_types = dict(
199
+ n_inner=int,
200
+ rotary_dim=int,
201
+ )
202
+ config = function_args_to_config(cls.__init__, none_arg_types=none_arg_types)
203
+
204
+ if updates is not None:
205
+ config.update(ConfigDict(updates).copy_and_resolve_references())
206
+
207
+ return config
208
+
209
+ @staticmethod
210
+ def get_jax_mesh(axis_dims):
211
+ return get_jax_mesh(axis_dims, ('dp', 'fsdp', 'mp'))
212
+
213
+ @staticmethod
214
+ def get_partition_rules():
215
+ """ Parition rules for GPTJ. Note that these rules are orderd, so that
216
+ the beginning rules match first. It is important to use
217
+ PartitionSpec() instead of None here because JAX does not treat
218
+ None as a pytree leaf.
219
+ """
220
+ return (
221
+ ('transformer/wte/embedding', PartitionSpec('mp', 'fsdp')),
222
+ ('attn/(k_proj|q_proj|v_proj)/kernel', PartitionSpec('fsdp', 'mp')),
223
+ ('attn/out_proj/kernel', PartitionSpec('mp', 'fsdp')),
224
+ ('mlp/fc_in/kernel', PartitionSpec('fsdp', 'mp')),
225
+ ('mlp/fc_in/bias', PartitionSpec('mp')),
226
+ ('mlp/fc_out/kernel', PartitionSpec('mp', 'fsdp')),
227
+ ('mlp/fc_out/bias', PartitionSpec()),
228
+ ('ln_[0-9]+/bias', PartitionSpec()),
229
+ ('[0-9]+/ln_[0-9]+/scale', PartitionSpec()),
230
+ ('ln_f/bias', PartitionSpec()),
231
+ ('ln_f/scale', PartitionSpec()),
232
+ ('lm_head/kernel', PartitionSpec('fsdp', 'mp')),
233
+ ('lm_head/bias', PartitionSpec('mp')),
234
+ ('.*', PartitionSpec()),
235
+ )
236
+
237
+ @staticmethod
238
+ def get_weight_decay_exclusions():
239
+ return (
240
+ 'ln_[0-9]+/bias', 'ln_[0-9]+/scale', 'ln_f/bias', 'ln_f/scale',
241
+ 'bias'
242
+ )
243
+
244
+ @staticmethod
245
+ def rng_keys():
246
+ return ('params', 'dropout', 'fcm')
247
+
248
+ @staticmethod
249
+ def get_tokenizer_config(updates=None):
250
+ config = ConfigDict()
251
+ config.name = 'EleutherAI/gpt-j-6B'
252
+ config.bos_token = '<|endoftext|>'
253
+ config.eos_token = '<|endoftext|>'
254
+ config.pad_token = '<|extratoken_40|>'
255
+ config.cls_token = '<|extratoken_41|>'
256
+ config.mask_token = '<|extratoken_42|>'
257
+
258
+ if updates is not None:
259
+ config.update(ConfigDict(updates).copy_and_resolve_references())
260
+
261
+ return config
262
+
263
+ @classmethod
264
+ def get_tokenizer(cls, config, padding_side='left', truncation_side='right'):
265
+ config = cls.get_tokenizer_config(config)
266
+ return AutoTokenizer.from_pretrained(
267
+ config.name,
268
+ bos_token=config.bos_token,
269
+ eos_token=config.eos_token,
270
+ pad_token=config.pad_token,
271
+ cls_token=config.cls_token,
272
+ mask_token=config.mask_token,
273
+ padding_side=padding_side,
274
+ truncation_side=truncation_side,
275
+ )
276
+
277
+ @staticmethod
278
+ def load_pretrained(name, dtype=jnp.float32):
279
+ with jax.default_device(jax.devices("cpu")[0]):
280
+ params = FlaxGPTJForCausalLM.from_pretrained(
281
+ name, _do_init=False, dtype=dtype
282
+ )[1]
283
+ params = freeze({'params': params})
284
+ return jax.device_get(params)
285
+
286
+ @classmethod
287
+ def load_config(cls, path):
288
+ if path in GPTJ_STANDARD_CONFIGS:
289
+ return cls.from_dict(GPTJ_STANDARD_CONFIGS[path])
290
+ load_type, load_path = path.split('::', 1)
291
+ if load_type == 'pickle':
292
+ return cls.from_dict(load_pickle(load_path)['gptj_config'])
293
+ elif load_type == 'json':
294
+ with open_file(load_path, 'r') as fin:
295
+ raw_config = fin.read()
296
+ return cls.from_dict(json.loads(raw_config))
297
+ elif load_type == 'huggingface':
298
+ return cls.from_pretrained(load_path)
299
+ else:
300
+ raise ValueError(f'Unsupported load config type: {load_type}')
301
+
302
+
303
+ """
304
+ The follow code is taken from
305
+ transformers/src/transformers/models/gptj/modeling_flax_gptj.py
306
+ and modified to work with EasyLM.
307
+ """
308
+
309
+ logger = logging.get_logger(__name__)
310
+
311
+ _CHECKPOINT_FOR_DOC = "gptj"
312
+ _CONFIG_FOR_DOC = "GPTJConfig"
313
+
314
+ remat = nn_partitioning.remat
315
+
316
+
317
+ GPTJ_START_DOCSTRING = r"""
318
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
319
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
320
+ etc.)
321
+ This model is also a Flax Linen
322
+ [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
323
+ regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
324
+ Finally, this model supports inherent JAX features such as:
325
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
326
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
327
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
328
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
329
+ Parameters:
330
+ config ([`GPTJConfig`]): Model configuration class with all the parameters of the model.
331
+ Initializing with a config file does not load the weights associated with the model, only the
332
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
333
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
334
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
335
+ `jax.numpy.bfloat16` (on TPUs).
336
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
337
+ specified all the computation will be performed with the given `dtype`.
338
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
339
+ parameters.**
340
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
341
+ [`~FlaxPreTrainedModel.to_bf16`].
342
+ """
343
+
344
+ GPTJ_INPUTS_DOCSTRING = r"""
345
+ Args:
346
+ input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`):
347
+ `input_ids_length` = `sequence_length`. Indices of input sequence tokens in the vocabulary.
348
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
349
+ [`PreTrainedTokenizer.__call__`] for details.
350
+ [What are input IDs?](../glossary#input-ids)
351
+ attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
352
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
353
+ - 1 for tokens that are **not masked**,
354
+ - 0 for tokens that are **masked**.
355
+ [What are attention masks?](../glossary#attention-mask)
356
+ position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
357
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
358
+ config.max_position_embeddings - 1]`.
359
+ past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
360
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
361
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
362
+ output_attentions (`bool`, *optional*):
363
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
364
+ tensors for more detail.
365
+ output_hidden_states (`bool`, *optional*):
366
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
367
+ more detail.
368
+ return_dict (`bool`, *optional*):
369
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
370
+ """
371
+
372
+
373
+
374
+ def create_sinusoidal_positions(num_pos, dim):
375
+ inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))
376
+ sinusoid_inp = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32")
377
+ sin, cos = np.sin(sinusoid_inp), np.cos(sinusoid_inp)
378
+
379
+ sentinel = dim // 2 + dim % 2
380
+ out = np.zeros((num_pos, dim))
381
+ out[:, 0:sentinel] = sin
382
+ out[:, sentinel:] = cos
383
+
384
+ return jnp.array(out)
385
+
386
+
387
+ def rotate_every_two(tensor):
388
+ rotate_half_tensor = jnp.stack((-tensor[:, :, :, 1::2], tensor[:, :, :, ::2]), axis=-1)
389
+ rotate_half_tensor = rotate_half_tensor.reshape(rotate_half_tensor.shape[:-2] + (-1,))
390
+ return rotate_half_tensor
391
+
392
+
393
+ def apply_rotary_pos_emb(tensor, sincos):
394
+ sin_pos, cos_pos = sincos
395
+ sin_pos = sin_pos[:, :, None, :].repeat(2, 3)
396
+ cos_pos = cos_pos[:, :, None, :].repeat(2, 3)
397
+ return (tensor * cos_pos) + (rotate_every_two(tensor) * sin_pos)
398
+
399
+
400
+ class FlaxGPTJAttention(nn.Module):
401
+ config: GPTJConfig
402
+ dtype: jnp.dtype = jnp.float32
403
+ causal: bool = True
404
+ is_cross_attention: bool = False
405
+
406
+ def setup(self):
407
+ config = self.config
408
+ self.embed_dim = config.hidden_size
409
+ self.num_heads = config.num_attention_heads
410
+ self.head_dim = self.embed_dim // self.num_heads
411
+
412
+ self.rotary_dim = config.rotary_dim
413
+
414
+ dense = partial(
415
+ nn.Dense,
416
+ self.embed_dim,
417
+ use_bias=False,
418
+ dtype=self.dtype,
419
+ kernel_init=jax.nn.initializers.variance_scaling(
420
+ scale=1.0, mode='fan_in',
421
+ distribution='normal',
422
+ )
423
+ )
424
+
425
+ self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
426
+ self.out_proj = dense()
427
+
428
+ self.resid_dropout = nn.Dropout(rate=config.resid_pdrop)
429
+
430
+ self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
431
+
432
+ if self.rotary_dim is not None and self.rotary_dim > 0:
433
+ pos_embd_dim = self.rotary_dim
434
+ else:
435
+ pos_embd_dim = self.embed_dim // self.num_heads
436
+ self.embed_positions = create_sinusoidal_positions(config.max_position_embeddings, pos_embd_dim)
437
+
438
+ def _split_heads(self, hidden_states):
439
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
440
+
441
+ def _merge_heads(self, hidden_states):
442
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
443
+
444
+ @nn.compact
445
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
446
+ """
447
+ This function takes projected key, value states from a single input token and concatenates the states to cached
448
+ states from previous steps. This function is slighly adapted from the official Flax repository:
449
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
450
+ """
451
+ # detect if we're initializing by absence of existing cache data.
452
+ is_initialized = self.has_variable("cache", "cached_key")
453
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
454
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
455
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
456
+
457
+ if is_initialized:
458
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
459
+ # update key, value caches with our new 1d spatial slices
460
+ cur_index = cache_index.value
461
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
462
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
463
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
464
+ cached_key.value = key
465
+ cached_value.value = value
466
+ num_updated_cache_vectors = query.shape[1]
467
+ cache_index.value = cache_index.value + num_updated_cache_vectors
468
+ # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
469
+ pad_mask = jnp.broadcast_to(
470
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
471
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
472
+ )
473
+ attention_mask = combine_masks(pad_mask, attention_mask)
474
+ return key, value, attention_mask
475
+
476
+ def __call__(
477
+ self,
478
+ hidden_states,
479
+ attention_mask,
480
+ position_ids,
481
+ deterministic: bool = True,
482
+ init_cache: bool = False,
483
+ output_attentions: bool = False,
484
+ fcm_mask=None,
485
+ ):
486
+
487
+ query = self.q_proj(hidden_states)
488
+ key = self.k_proj(hidden_states)
489
+ value = self.v_proj(hidden_states)
490
+
491
+ query = self._split_heads(query)
492
+ key = self._split_heads(key)
493
+ value = self._split_heads(value)
494
+
495
+ sincos = jnp.take(self.embed_positions, position_ids, axis=0)
496
+ sincos = jnp.split(sincos, 2, axis=-1)
497
+ # Rotary position embeddings induce some weird issues in multi-host environments, so we remove activation-sharding for keys/query vectors to fix this.
498
+ # key = with_sharding_constraint(key, PartitionSpec("dp", None, None, None))
499
+ # query = with_sharding_constraint(query, PartitionSpec("dp", None, None, None))
500
+ if self.rotary_dim is not None and self.rotary_dim > 0:
501
+ k_rot = key[:, :, :, : self.rotary_dim]
502
+ k_pass = key[:, :, :, self.rotary_dim :]
503
+
504
+ q_rot = query[:, :, :, : self.rotary_dim]
505
+ q_pass = query[:, :, :, self.rotary_dim :]
506
+
507
+ k_rot = apply_rotary_pos_emb(k_rot, sincos)
508
+ q_rot = apply_rotary_pos_emb(q_rot, sincos)
509
+
510
+ key = jnp.concatenate([k_rot, k_pass], axis=-1)
511
+ query = jnp.concatenate([q_rot, q_pass], axis=-1)
512
+ else:
513
+ key = apply_rotary_pos_emb(key, sincos)
514
+ query = apply_rotary_pos_emb(query, sincos)
515
+
516
+ query_length, key_length = query.shape[1], key.shape[1]
517
+
518
+ if self.has_variable("cache", "cached_key"):
519
+ mask_shift = self.variables["cache"]["cache_index"]
520
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
521
+ causal_mask = lax.dynamic_slice(
522
+ self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
523
+ )
524
+ else:
525
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
526
+
527
+ batch_size = hidden_states.shape[0]
528
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
529
+
530
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
531
+ if self.causal:
532
+ attention_mask = combine_masks(attention_mask, causal_mask, fcm_mask)
533
+ else:
534
+ attention_mask = attention_mask
535
+
536
+ dropout_rng = None
537
+ if not deterministic and self.config.attn_pdrop > 0.0:
538
+ dropout_rng = self.make_rng("dropout")
539
+
540
+ # During fast autoregressive decoding, we feed one position at a time,
541
+ # and cache the keys and values step by step.
542
+ if self.has_variable("cache", "cached_key") or init_cache:
543
+ key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)
544
+
545
+ # transform boolean mask into float mask
546
+ attention_bias = lax.select(
547
+ attention_mask > 0,
548
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
549
+ jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
550
+ )
551
+
552
+ # usual dot product attention
553
+ attn_weights = dot_product_attention_weights(
554
+ query,
555
+ key,
556
+ bias=attention_bias,
557
+ dropout_rng=dropout_rng,
558
+ dropout_rate=self.config.attn_pdrop,
559
+ deterministic=deterministic,
560
+ dtype=jnp.promote_types(self.dtype, jnp.float32),
561
+ precision=None,
562
+ )
563
+
564
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
565
+ attn_output = self._merge_heads(attn_output)
566
+ attn_output = self.out_proj(attn_output)
567
+ attn_output = self.resid_dropout(attn_output, deterministic=deterministic)
568
+
569
+ outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
570
+ return outputs
571
+
572
+
573
+ class FlaxGPTJMLP(nn.Module):
574
+ config: GPTJConfig
575
+ intermediate_size: int
576
+ dtype: jnp.dtype = jnp.float32
577
+
578
+ def setup(self):
579
+ embed_dim = self.config.hidden_size
580
+ kernel_init=jax.nn.initializers.variance_scaling(
581
+ scale=1.0, mode='fan_in',
582
+ distribution='normal',
583
+ )
584
+
585
+ self.fc_in = nn.Dense(self.intermediate_size, dtype=self.dtype, kernel_init=kernel_init)
586
+ self.fc_out = nn.Dense(embed_dim, dtype=self.dtype, kernel_init=kernel_init)
587
+
588
+ self.act = ACT2FN[self.config.activation_function]
589
+ self.dropout = nn.Dropout(rate=self.config.resid_pdrop)
590
+
591
+ def __call__(self, hidden_states, deterministic: bool = True):
592
+ hidden_states = self.fc_in(hidden_states)
593
+ hidden_states = self.act(hidden_states)
594
+ hidden_states = self.fc_out(hidden_states)
595
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
596
+ return hidden_states
597
+
598
+
599
+ class FlaxGPTJBlock(nn.Module):
600
+ config: GPTJConfig
601
+ dtype: jnp.dtype = jnp.float32
602
+
603
+ def setup(self):
604
+ hidden_size = self.config.hidden_size
605
+ inner_dim = self.config.n_inner if self.config.n_inner is not None else 4 * hidden_size
606
+
607
+ self.ln_1 = nn.LayerNorm(
608
+ epsilon=self.config.layer_norm_epsilon,
609
+ dtype=jnp.promote_types(self.dtype, jnp.float32)
610
+ )
611
+ self.attn = FlaxGPTJAttention(self.config, dtype=self.dtype)
612
+
613
+ self.mlp = FlaxGPTJMLP(self.config, inner_dim, dtype=self.dtype)
614
+
615
+ def __call__(
616
+ self,
617
+ hidden_states,
618
+ attention_mask=None,
619
+ position_ids=None,
620
+ deterministic: bool = True,
621
+ init_cache: bool = False,
622
+ output_attentions: bool = False,
623
+ fcm_mask=None,
624
+ ):
625
+ residual = hidden_states
626
+ hidden_states = self.ln_1(hidden_states)
627
+ attn_outputs = self.attn(
628
+ hidden_states,
629
+ attention_mask=attention_mask,
630
+ position_ids=position_ids,
631
+ deterministic=deterministic,
632
+ init_cache=init_cache,
633
+ output_attentions=output_attentions,
634
+ fcm_mask=fcm_mask,
635
+ )
636
+ attn_output = attn_outputs[0]
637
+
638
+ feed_forward_hidden_states = self.mlp(hidden_states, deterministic=deterministic)
639
+ # residual connection
640
+ hidden_states = attn_output + feed_forward_hidden_states + residual
641
+
642
+ return (hidden_states,) + attn_outputs[1:]
643
+
644
+
645
+ class FlaxGPTJPreTrainedModel(FlaxPreTrainedModel):
646
+ """
647
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
648
+ models.
649
+ """
650
+
651
+ config_class = GPTJConfig
652
+ base_model_prefix = "transformer"
653
+ module_class: nn.Module = None
654
+
655
+ def __init__(
656
+ self,
657
+ config: GPTJConfig,
658
+ input_shape: Tuple = (1, 1),
659
+ seed: int = 0,
660
+ dtype: jnp.dtype = jnp.float32,
661
+ _do_init: bool = True,
662
+ **kwargs,
663
+ ):
664
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
665
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
666
+
667
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
668
+ # init input tensors
669
+ input_ids = jnp.zeros(input_shape, dtype="i4")
670
+ attention_mask = jnp.ones_like(input_ids)
671
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
672
+ params_rng, dropout_rng = jax.random.split(rng)
673
+ rngs = {"params": params_rng, "dropout": dropout_rng}
674
+
675
+ if self.config.add_cross_attention:
676
+ encoder_hidden_states = jnp.zeros(input_shape + (self.config.n_embd,))
677
+ encoder_attention_mask = attention_mask
678
+ module_init_outputs = self.module.init(
679
+ rngs,
680
+ input_ids,
681
+ attention_mask,
682
+ position_ids,
683
+ encoder_hidden_states,
684
+ encoder_attention_mask,
685
+ return_dict=False,
686
+ )
687
+ else:
688
+ module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)
689
+
690
+ random_params = module_init_outputs["params"]
691
+
692
+ if params is not None:
693
+ random_params = flatten_dict(unfreeze(random_params))
694
+ params = flatten_dict(unfreeze(params))
695
+ for missing_key in self._missing_keys:
696
+ params[missing_key] = random_params[missing_key]
697
+ self._missing_keys = set()
698
+ return freeze(unflatten_dict(params))
699
+ else:
700
+ return random_params
701
+
702
+ def init_cache(self, batch_size, max_length):
703
+ r"""
704
+ Args:
705
+ batch_size (`int`):
706
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
707
+ max_length (`int`):
708
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
709
+ cache.
710
+ """
711
+ # init input variables to retrieve cache
712
+ input_ids = jnp.ones((batch_size, max_length))
713
+ attention_mask = jnp.ones_like(input_ids)
714
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
715
+
716
+ init_variables = self.module.init(
717
+ jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
718
+ )
719
+ return init_variables["cache"]
720
+
721
+ def _get_logits_processor(self,*args, **kwargs) -> FlaxLogitsProcessorList:
722
+ processors = super()._get_logits_processor(*args, **kwargs)
723
+ def squash_extra_tokens(input_ids, scores, cur_len):
724
+ return scores.at[:, self.config.n_real_tokens:].set(-float('inf'))
725
+
726
+ processors.append(squash_extra_tokens)
727
+ return processors
728
+
729
+ @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING)
730
+ def __call__(
731
+ self,
732
+ input_ids,
733
+ attention_mask=None,
734
+ position_ids=None,
735
+ params: dict = None,
736
+ past_key_values: dict = None,
737
+ dropout_rng: jax.random.PRNGKey = None,
738
+ train: bool = False,
739
+ output_attentions: Optional[bool] = None,
740
+ output_hidden_states: Optional[bool] = None,
741
+ return_dict: Optional[bool] = None,
742
+ ):
743
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
744
+ output_hidden_states = (
745
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
746
+ )
747
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
748
+
749
+ batch_size, sequence_length = input_ids.shape
750
+
751
+ if position_ids is None:
752
+ if past_key_values is not None:
753
+ raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")
754
+
755
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
756
+
757
+ if attention_mask is None:
758
+ attention_mask = jnp.ones((batch_size, sequence_length))
759
+
760
+ # Handle any PRNG if needed
761
+ rngs = {}
762
+ if dropout_rng is not None:
763
+ rngs["dropout"] = dropout_rng
764
+
765
+ inputs = {"params": params or self.params}
766
+
767
+ # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be changed by FlaxGPTJAttention module
768
+ if past_key_values:
769
+ inputs["cache"] = past_key_values
770
+ mutable = ["cache"]
771
+ else:
772
+ mutable = False
773
+
774
+ outputs = self.module.apply(
775
+ inputs,
776
+ jnp.array(input_ids, dtype="i4"),
777
+ jnp.array(attention_mask, dtype="i4"),
778
+ jnp.array(position_ids, dtype="i4"),
779
+ not train,
780
+ False,
781
+ output_attentions,
782
+ output_hidden_states,
783
+ return_dict,
784
+ rngs=rngs,
785
+ mutable=mutable,
786
+ )
787
+
788
+ # add updated cache to model output
789
+ if past_key_values is not None and return_dict:
790
+ outputs, past_key_values = outputs
791
+ outputs["past_key_values"] = unfreeze(past_key_values["cache"])
792
+ return outputs
793
+ elif past_key_values is not None and not return_dict:
794
+ outputs, past_key_values = outputs
795
+ outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
796
+
797
+ return outputs
798
+
799
+
800
+ class FlaxGPTJBlockCollection(nn.Module):
801
+ config: GPTJConfig
802
+ dtype: jnp.dtype = jnp.float32
803
+
804
+ def setup(self):
805
+ block = FlaxGPTJBlock
806
+ if self.config.gradient_checkpointing:
807
+ FlaxGPT2CheckpointBlock = remat(
808
+ block, static_argnums=(3, 4, 5),
809
+ policy=get_gradient_checkpoint_policy(
810
+ self.config.gradient_checkpointing_policy
811
+ )
812
+ )
813
+ block = FlaxGPT2CheckpointBlock
814
+ self.blocks = [
815
+ block(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
816
+ ]
817
+
818
+ def __call__(
819
+ self,
820
+ hidden_states,
821
+ attention_mask=None,
822
+ position_ids=None,
823
+ deterministic: bool = True,
824
+ init_cache: bool = False,
825
+ output_attentions: bool = False,
826
+ output_hidden_states: bool = False,
827
+ return_dict: bool = True,
828
+ ):
829
+ all_attentions = () if output_attentions else None
830
+ all_hidden_states = () if output_hidden_states else None
831
+
832
+ if not deterministic and self.config.fcm_max_ratio > 0:
833
+ # Apply forgetful causal mask
834
+ batch_size, seq_length = hidden_states.shape[0], hidden_states.shape[1]
835
+ fcm_ratio = jax.random.uniform(
836
+ self.make_rng('fcm'), shape=(batch_size, 1, 1, 1),
837
+ minval=self.config.fcm_min_ratio,
838
+ maxval=self.config.fcm_max_ratio
839
+ )
840
+ fcm_mask = jax.random.uniform(
841
+ self.make_rng('fcm'),
842
+ shape=(batch_size, 1, seq_length, seq_length)
843
+ ) > fcm_ratio
844
+ fcm_mask = fcm_mask.at[:, :, :, 0].set(True)
845
+ fcm_mask = fcm_mask.astype('bool')
846
+ else:
847
+ fcm_mask = None
848
+
849
+ for block in self.blocks:
850
+ if output_hidden_states:
851
+ all_hidden_states += (hidden_states,)
852
+
853
+ layer_outputs = block(
854
+ hidden_states,
855
+ attention_mask,
856
+ position_ids,
857
+ deterministic,
858
+ init_cache,
859
+ output_attentions,
860
+ fcm_mask,
861
+ )
862
+ hidden_states = layer_outputs[0]
863
+
864
+ if output_attentions:
865
+ all_attentions += (layer_outputs[1],)
866
+
867
+ # this contains possible `None` values - `FlaxGPTJModule` will filter them out
868
+ outputs = (hidden_states, all_hidden_states, all_attentions)
869
+
870
+ return outputs
871
+
872
+
873
+ class FlaxGPTJModule(nn.Module):
874
+ config: GPTJConfig
875
+ dtype: jnp.dtype = jnp.float32
876
+
877
+ def setup(self):
878
+ self.embed_dim = self.config.hidden_size
879
+
880
+ self.wte = nn.Embed(
881
+ self.config.vocab_size,
882
+ self.config.hidden_size,
883
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
884
+ )
885
+ self.dropout = nn.Dropout(rate=self.config.embd_pdrop)
886
+ self.h = FlaxGPTJBlockCollection(self.config, dtype=self.dtype)
887
+ self.ln_f = nn.LayerNorm(
888
+ epsilon=self.config.layer_norm_epsilon,
889
+ dtype=jnp.promote_types(self.dtype, jnp.float32)
890
+ )
891
+
892
+ def __call__(
893
+ self,
894
+ input_ids,
895
+ attention_mask,
896
+ position_ids,
897
+ deterministic=True,
898
+ init_cache: bool = False,
899
+ output_attentions: bool = False,
900
+ output_hidden_states: bool = False,
901
+ return_dict: bool = True,
902
+ ):
903
+ input_embeds = self.wte(input_ids.astype("i4"))
904
+
905
+ hidden_states = self.dropout(input_embeds, deterministic=deterministic)
906
+
907
+ outputs = self.h(
908
+ hidden_states,
909
+ attention_mask,
910
+ position_ids=position_ids,
911
+ deterministic=deterministic,
912
+ init_cache=init_cache,
913
+ output_attentions=output_attentions,
914
+ output_hidden_states=output_hidden_states,
915
+ return_dict=return_dict,
916
+ )
917
+
918
+ hidden_states = outputs[0]
919
+ hidden_states = self.ln_f(hidden_states)
920
+
921
+ if output_hidden_states:
922
+ all_hidden_states = outputs[1] + (hidden_states,)
923
+ outputs = (hidden_states, all_hidden_states) + outputs[2:]
924
+ else:
925
+ outputs = (hidden_states,) + outputs[1:]
926
+
927
+ if not return_dict:
928
+ return tuple(v for v in outputs if v is not None)
929
+
930
+ return FlaxBaseModelOutput(
931
+ last_hidden_state=hidden_states,
932
+ hidden_states=outputs[1],
933
+ attentions=outputs[-1],
934
+ )
935
+
936
+
937
+ @add_start_docstrings(
938
+ "The bare GPTJ Model transformer outputting raw hidden-states without any specific head on top.",
939
+ GPTJ_START_DOCSTRING,
940
+ )
941
+ class FlaxGPTJModel(FlaxGPTJPreTrainedModel):
942
+ module_class = FlaxGPTJModule
943
+
944
+
945
+ append_call_sample_docstring(
946
+ FlaxGPTJModel,
947
+ _CHECKPOINT_FOR_DOC,
948
+ FlaxCausalLMOutput,
949
+ _CONFIG_FOR_DOC,
950
+ )
951
+
952
+
953
+ class FlaxGPTJForCausalLMModule(nn.Module):
954
+ config: GPTJConfig
955
+ dtype: jnp.dtype = jnp.float32
956
+
957
+ def setup(self):
958
+ self.transformer = FlaxGPTJModule(self.config, dtype=self.dtype)
959
+ self.lm_head = nn.Dense(
960
+ self.config.vocab_size,
961
+ dtype=self.dtype,
962
+ kernel_init=jax.nn.initializers.variance_scaling(
963
+ scale=1.0, mode='fan_in',
964
+ distribution='normal',
965
+ )
966
+ )
967
+
968
+ def __call__(
969
+ self,
970
+ input_ids,
971
+ attention_mask=None,
972
+ position_ids=None,
973
+ deterministic: bool = True,
974
+ init_cache: bool = False,
975
+ output_attentions: bool = False,
976
+ output_hidden_states: bool = False,
977
+ return_dict: bool = True,
978
+ ):
979
+ batch_size, seq_length = input_ids.shape
980
+ if attention_mask is None:
981
+ attention_mask = jnp.ones_like(input_ids)
982
+ if position_ids is None:
983
+ position_ids = jnp.broadcast_to(
984
+ jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0),
985
+ (batch_size, seq_length)
986
+ )
987
+
988
+ outputs = self.transformer(
989
+ input_ids,
990
+ attention_mask,
991
+ position_ids,
992
+ deterministic=deterministic,
993
+ init_cache=init_cache,
994
+ output_attentions=output_attentions,
995
+ output_hidden_states=output_hidden_states,
996
+ return_dict=return_dict,
997
+ )
998
+
999
+ hidden_states = outputs[0]
1000
+
1001
+ if self.config.tie_word_embeddings:
1002
+ shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T
1003
+ lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
1004
+ else:
1005
+ lm_logits = self.lm_head(hidden_states)
1006
+
1007
+ if not return_dict:
1008
+ return (lm_logits,) + outputs[1:]
1009
+
1010
+ return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
1011
+
1012
+
1013
+ @add_start_docstrings(
1014
+ """
1015
+ The GPTJ Model transformer with a language modeling head on top.
1016
+ """,
1017
+ GPTJ_START_DOCSTRING,
1018
+ )
1019
+ class FlaxGPTJForCausalLM(FlaxGPTJPreTrainedModel):
1020
+ module_class = FlaxGPTJForCausalLMModule
1021
+
1022
+ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
1023
+ # initializing the cache
1024
+ batch_size, seq_length = input_ids.shape
1025
+
1026
+ past_key_values = self.init_cache(batch_size, max_length)
1027
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
1028
+ # But since GPTJ uses a causal mask, those positions are masked anyways.
1029
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation
1030
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
1031
+ if attention_mask is not None:
1032
+ position_ids = attention_mask.cumsum(axis=-1) - 1
1033
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
1034
+ else:
1035
+ position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
1036
+
1037
+ return {
1038
+ "past_key_values": past_key_values,
1039
+ "attention_mask": extended_attention_mask,
1040
+ "position_ids": position_ids,
1041
+ }
1042
+
1043
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
1044
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
1045
+ model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
1046
+ return model_kwargs
1047
+
1048
+
1049
+ append_call_sample_docstring(
1050
+ FlaxGPTJForCausalLM,
1051
+ _CHECKPOINT_FOR_DOC,
1052
+ FlaxCausalLMOutput,
1053
+ _CONFIG_FOR_DOC,
1054
+ )
EasyLM/models/gptj/gptj_serve.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pprint
2
+ from functools import partial
3
+
4
+ import numpy as np
5
+ import mlxu
6
+
7
+ import jax
8
+ import jax.numpy as jnp
9
+ from jax.experimental.pjit import pjit
10
+ from jax.sharding import PartitionSpec as PS
11
+ import flax
12
+ from flax import linen as nn
13
+ from flax.jax_utils import prefetch_to_device
14
+ from flax.training.train_state import TrainState
15
+ import optax
16
+ from transformers import GenerationConfig, FlaxLogitsProcessorList
17
+
18
+ from EasyLM.checkpoint import StreamingCheckpointer
19
+ from EasyLM.serving import LMServer
20
+ from EasyLM.jax_utils import (
21
+ JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, tree_apply,
22
+ set_random_seed, get_float_dtype_by_name, make_shard_and_gather_fns,
23
+ with_sharding_constraint, FlaxTemperatureLogitsWarper
24
+ )
25
+ from EasyLM.models.gptj.gptj_model import (
26
+ GPTJConfig, FlaxGPTJForCausalLMModule, FlaxGPTJForCausalLM
27
+ )
28
+
29
+
30
+ FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
31
+ seed=42,
32
+ initialize_jax_distributed=False,
33
+ mesh_dim='1,-1,1',
34
+ dtype='bf16',
35
+ input_length=1024,
36
+ seq_length=2048,
37
+ top_k=50,
38
+ top_p=1.0,
39
+ do_sample=True,
40
+ num_beams=1,
41
+ add_bos_token=False,
42
+ load_gptj_config='',
43
+ load_checkpoint='',
44
+ tokenizer=GPTJConfig.get_tokenizer_config(),
45
+ lm_server=LMServer.get_default_config(),
46
+ jax_distributed=JaxDistributedConfig.get_default_config(),
47
+ )
48
+
49
+
50
+ def main(argv):
51
+ JaxDistributedConfig.initialize(FLAGS.jax_distributed)
52
+ set_random_seed(FLAGS.seed)
53
+
54
+ prefix_tokenizer = GPTJConfig.get_tokenizer(
55
+ FLAGS.tokenizer, truncation_side='left', padding_side='left'
56
+ )
57
+ tokenizer = GPTJConfig.get_tokenizer(
58
+ FLAGS.tokenizer, truncation_side='right', padding_side='right'
59
+ )
60
+
61
+ with jax.default_device(jax.devices("cpu")[0]):
62
+ gptj_config = GPTJConfig.load_config(FLAGS.load_gptj_config)
63
+ load_type, load_path = FLAGS.load_checkpoint.split('::', 1)
64
+ if load_type == 'huggingface':
65
+ params = gptj_config.load_pretrained(load_path)
66
+ else:
67
+ _, params = StreamingCheckpointer.load_trainstate_checkpoint(
68
+ FLAGS.load_checkpoint, disallow_trainstate=True
69
+ )
70
+
71
+ hf_model = FlaxGPTJForCausalLM(
72
+ gptj_config,
73
+ input_shape=(1, FLAGS.seq_length),
74
+ seed=FLAGS.seed,
75
+ _do_init=False
76
+ )
77
+
78
+ model_ps = match_partition_rules(
79
+ GPTJConfig.get_partition_rules(), params
80
+ )
81
+ shard_fns, _ = make_shard_and_gather_fns(
82
+ model_ps, get_float_dtype_by_name(FLAGS.dtype)
83
+ )
84
+
85
+ @partial(
86
+ pjit,
87
+ in_shardings=(model_ps, PS(), PS()),
88
+ out_shardings=(PS(), PS(), PS())
89
+ )
90
+ def forward_loglikelihood(params, rng, batch):
91
+ batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
92
+ rng_generator = JaxRNG(rng)
93
+ input_tokens = batch['input_tokens']
94
+ output_tokens = batch['output_tokens']
95
+ input_mask = batch['input_mask']
96
+ output_mask = batch['output_mask']
97
+
98
+ logits = hf_model.module.apply(
99
+ params, input_tokens, attention_mask=input_mask,
100
+ deterministic=True, rngs=rng_generator(gptj_config.rng_keys()),
101
+ ).logits
102
+ if gptj_config.n_real_tokens is not None:
103
+ logits = logits.at[:, :, gptj_config.n_real_tokens:].set(-1e8)
104
+ loglikelihood = -optax.softmax_cross_entropy_with_integer_labels(
105
+ logits, output_tokens
106
+ )
107
+ loglikelihood = jnp.sum(loglikelihood * output_mask, axis=-1)
108
+ match_count = jnp.sum(
109
+ (jnp.argmax(logits, axis=-1) == output_tokens) * output_mask,
110
+ axis=-1
111
+ )
112
+ total = jnp.sum(output_mask, axis=-1)
113
+ is_greedy = match_count == total
114
+ return loglikelihood, is_greedy, rng_generator()
115
+
116
+
117
+ @partial(
118
+ pjit,
119
+ in_shardings=(model_ps, PS(), PS(), PS()),
120
+ out_shardings=(PS(), PS())
121
+ )
122
+ def forward_generate(params, rng, batch, temperature):
123
+ batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
124
+ rng_generator = JaxRNG(rng)
125
+ output = hf_model.generate(
126
+ batch['input_tokens'],
127
+ attention_mask=batch['attention_mask'],
128
+ params=params['params'],
129
+ prng_key=rng_generator(),
130
+ logits_processor=FlaxLogitsProcessorList(
131
+ [FlaxTemperatureLogitsWarper(temperature)]
132
+ ),
133
+ generation_config=GenerationConfig(
134
+ max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
135
+ pad_token_id=tokenizer.eos_token_id,
136
+ bos_token_id=tokenizer.bos_token_id,
137
+ eos_token_id=tokenizer.eos_token_id,
138
+ do_sample=FLAGS.do_sample,
139
+ num_beams=FLAGS.num_beams,
140
+ top_k=FLAGS.top_k,
141
+ top_p=FLAGS.top_p,
142
+ )
143
+ ).sequences[:, batch['input_tokens'].shape[1]:]
144
+ return output, rng_generator()
145
+
146
+ @partial(
147
+ pjit,
148
+ in_shardings=(model_ps, PS(), PS()),
149
+ out_shardings=(PS(), PS())
150
+ )
151
+ def forward_greedy_generate(params, rng, batch):
152
+ batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
153
+ rng_generator = JaxRNG(rng)
154
+ output = hf_model.generate(
155
+ batch['input_tokens'],
156
+ attention_mask=batch['attention_mask'],
157
+ params=params['params'],
158
+ prng_key=rng_generator(),
159
+ generation_config=GenerationConfig(
160
+ max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
161
+ pad_token_id=tokenizer.eos_token_id,
162
+ bos_token_id=tokenizer.bos_token_id,
163
+ eos_token_id=tokenizer.eos_token_id,
164
+ do_sample=False,
165
+ num_beams=1,
166
+ )
167
+ ).sequences[:, batch['input_tokens'].shape[1]:]
168
+ return output, rng_generator()
169
+
170
+ mesh = GPTJConfig.get_jax_mesh(FLAGS.mesh_dim)
171
+ with mesh:
172
+ params = tree_apply(shard_fns, params)
173
+ sharded_rng = next_rng()
174
+
175
+ class ModelServer(LMServer):
176
+
177
+ @staticmethod
178
+ def loglikelihood(prefix_text, text):
179
+ nonlocal sharded_rng
180
+ prefix = prefix_tokenizer(
181
+ prefix_text,
182
+ padding='max_length',
183
+ truncation=True,
184
+ max_length=FLAGS.input_length,
185
+ return_tensors='np',
186
+ )
187
+ inputs = tokenizer(
188
+ text,
189
+ padding='max_length',
190
+ truncation=True,
191
+ max_length=FLAGS.seq_length - FLAGS.input_length,
192
+ return_tensors='np',
193
+ )
194
+ output_tokens = np.concatenate([prefix.input_ids, inputs.input_ids], axis=1)
195
+ bos_tokens = np.full(
196
+ (output_tokens.shape[0], 1), tokenizer.bos_token_id, dtype=np.int32
197
+ )
198
+ input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1)
199
+ input_mask = np.concatenate(
200
+ [prefix.attention_mask, inputs.attention_mask], axis=1
201
+ )
202
+ if FLAGS.add_bos_token:
203
+ bos_mask = np.ones_like(input_mask[:, :1])
204
+ else:
205
+ bos_mask = np.zeros_like(input_mask[:, :1])
206
+
207
+ input_mask = np.concatenate([bos_mask, input_mask[:, :-1]], axis=1)
208
+ output_mask = np.concatenate(
209
+ [np.zeros_like(prefix.attention_mask), inputs.attention_mask], axis=1
210
+ )
211
+ batch = dict(
212
+ input_tokens=input_tokens,
213
+ output_tokens=output_tokens,
214
+ input_mask=input_mask,
215
+ output_mask=output_mask,
216
+ )
217
+ with mesh:
218
+ loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
219
+ params, sharded_rng, batch
220
+ )
221
+ loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy))
222
+ return loglikelihood, is_greedy
223
+
224
+ @staticmethod
225
+ def loglikelihood_rolling(text):
226
+ nonlocal sharded_rng
227
+ inputs = tokenizer(
228
+ text,
229
+ padding='longest',
230
+ truncation=False,
231
+ max_length=np.iinfo(np.int32).max,
232
+ return_tensors='np',
233
+ )
234
+ batch_size = inputs.input_ids.shape[0]
235
+ output_tokens = inputs.input_ids
236
+ attention_mask = inputs.attention_mask
237
+
238
+ if output_tokens.shape[1] < FLAGS.seq_length:
239
+ padding_length = FLAGS.seq_length - output_tokens.shape[1]
240
+ pad_tokens = np.full(
241
+ (batch_size, padding_length), tokenizer.pad_token_id, dtype=np.int32
242
+ )
243
+ output_tokens = np.concatenate([output_tokens, pad_tokens], axis=-1)
244
+ pad_mask = np.zeros(
245
+ (batch_size, padding_length), dtype=inputs.attention_mask.dtype
246
+ )
247
+ attention_mask = np.concatenate([attention_mask, pad_mask], axis=-1)
248
+
249
+ bos_tokens = np.full(
250
+ (batch_size, 1), tokenizer.bos_token_id, dtype=np.int32
251
+ )
252
+ input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1)
253
+ bos_mask = np.ones((batch_size, 1), dtype=inputs.attention_mask.dtype)
254
+ total_seq_length = output_tokens.shape[1]
255
+
256
+ total_loglikelihood = 0.0
257
+ total_is_greedy = True
258
+ # Sliding window
259
+ for i in range(0, total_seq_length, FLAGS.seq_length):
260
+ # Last window
261
+ if i + FLAGS.seq_length > total_seq_length:
262
+ last_output_mask = np.copy(attention_mask[:, -FLAGS.seq_length:])
263
+ last_output_mask[:, :i - total_seq_length] = 0.0
264
+
265
+ batch = dict(
266
+ input_tokens=input_tokens[:, -FLAGS.seq_length:],
267
+ output_tokens=output_tokens[:, -FLAGS.seq_length:],
268
+ input_mask=attention_mask[:, -FLAGS.seq_length:],
269
+ output_mask=last_output_mask,
270
+ )
271
+
272
+ # Normal window
273
+ else:
274
+ batch = dict(
275
+ input_tokens=input_tokens[:, i:i + FLAGS.seq_length],
276
+ output_tokens=output_tokens[:, i:i + FLAGS.seq_length],
277
+ input_mask=attention_mask[:, i:i + FLAGS.seq_length],
278
+ output_mask=attention_mask[:, i:i + FLAGS.seq_length],
279
+ )
280
+
281
+ with mesh:
282
+ loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
283
+ params, sharded_rng, batch
284
+ )
285
+ loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy))
286
+
287
+ total_loglikelihood += loglikelihood
288
+ total_is_greedy = np.logical_and(is_greedy, total_is_greedy)
289
+
290
+ return total_loglikelihood, total_is_greedy
291
+
292
+ @staticmethod
293
+ def generate(text, temperature):
294
+ nonlocal sharded_rng
295
+ inputs = prefix_tokenizer(
296
+ text,
297
+ padding='max_length',
298
+ truncation=True,
299
+ max_length=FLAGS.input_length,
300
+ return_tensors='np',
301
+ )
302
+ input_tokens = inputs.input_ids
303
+ input_mask = inputs.attention_mask
304
+ if FLAGS.add_bos_token:
305
+ input_tokens[:, 0] = tokenizer.bos_token_id
306
+ input_mask[:, 0] = 1
307
+ batch = dict(
308
+ input_tokens=input_tokens,
309
+ attention_mask=input_mask,
310
+ )
311
+ with mesh:
312
+ output, sharded_rng = forward_generate(
313
+ params, sharded_rng, batch, temperature
314
+ )
315
+ output = jax.device_get(output)
316
+ output_text = []
317
+ for text in list(tokenizer.batch_decode(output)):
318
+ if tokenizer.eos_token in text:
319
+ text = text.split(tokenizer.eos_token, maxsplit=1)[0]
320
+ output_text.append(text)
321
+
322
+ return output_text
323
+
324
+ @staticmethod
325
+ def greedy_until(prefix_text, until, max_length):
326
+ nonlocal sharded_rng
327
+ all_outputs = []
328
+ for pf, ut in zip(prefix_text, until):
329
+ if isinstance(ut, str):
330
+ ut = [ut]
331
+ total_length = 0
332
+ total_generated = ''
333
+
334
+ while total_length < max_length:
335
+ pf_tokens = tokenizer(
336
+ pf,
337
+ padding=False,
338
+ truncation=False,
339
+ max_length=np.iinfo(np.int32).max,
340
+ return_tensors='np',
341
+ )
342
+ input_tokens = pf_tokens.input_ids
343
+ attention_mask = pf_tokens.attention_mask
344
+
345
+ if input_tokens.shape[1] < FLAGS.input_length:
346
+ extra = FLAGS.input_length - input_tokens.shape[1]
347
+ pad_tokens = np.full(
348
+ (1, extra), tokenizer.pad_token_id, dtype=np.int32
349
+ )
350
+ input_tokens = np.concatenate(
351
+ [pad_tokens, input_tokens], axis=1
352
+ )
353
+ pad_attention = np.zeros((1, extra), dtype=attention_mask.dtype)
354
+ attention_mask = np.concatenate(
355
+ [pad_attention, attention_mask], axis=1
356
+ )
357
+ elif input_tokens.shape[1] > FLAGS.input_length:
358
+ input_tokens = input_tokens[:, -FLAGS.input_length:]
359
+ attention_mask = attention_mask[:, -FLAGS.input_length:]
360
+
361
+ if FLAGS.add_bos_token:
362
+ input_tokens[:, 0] = tokenizer.bos_token_id
363
+ attention_mask[:, 0] = 1
364
+
365
+ batch = dict(input_tokens=input_tokens, attention_mask=attention_mask)
366
+
367
+ with mesh:
368
+ output, sharded_rng = forward_greedy_generate(
369
+ params, sharded_rng, batch
370
+ )
371
+ output = jax.device_get(output)
372
+
373
+ total_length += output.shape[1]
374
+ output_text = tokenizer.batch_decode(output)[0]
375
+ total_generated = total_generated + output_text
376
+ pf = pf + output_text
377
+
378
+ done = False
379
+ for s in ut:
380
+ if s in total_generated:
381
+ total_generated = total_generated.split(s, maxsplit=1)[0]
382
+ done = True
383
+ if done:
384
+ break
385
+
386
+ all_outputs.append(total_generated)
387
+
388
+ return all_outputs
389
+
390
+
391
+ server = ModelServer(FLAGS.lm_server)
392
+ server.run()
393
+
394
+
395
+ if __name__ == "__main__":
396
+ mlxu.run(main)
EasyLM/models/gptj/gptj_train.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pprint
2
+ from functools import partial
3
+
4
+ from tqdm import tqdm, trange
5
+ import numpy as np
6
+ import mlxu
7
+
8
+ import jax
9
+ import jax.numpy as jnp
10
+ from jax.experimental.pjit import pjit, with_sharding_constraint
11
+ from jax.sharding import PartitionSpec as PS
12
+ from flax.training.train_state import TrainState
13
+
14
+ from EasyLM.data import DatasetFactory
15
+ from EasyLM.checkpoint import StreamingCheckpointer
16
+ from EasyLM.optimizers import OptimizerFactory
17
+ from EasyLM.jax_utils import (
18
+ JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules,
19
+ cross_entropy_loss_and_accuracy, global_norm, get_float_dtype_by_name,
20
+ set_random_seed, average_metrics, get_weight_decay_mask,
21
+ make_shard_and_gather_fns, tree_apply
22
+ )
23
+ from EasyLM.models.gptj.gptj_model import GPTJConfig, FlaxGPTJForCausalLMModule
24
+
25
+
26
+ FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
27
+ seed=42,
28
+ mesh_dim='1,-1,1',
29
+ dtype='fp32',
30
+ total_steps=10000,
31
+ load_gptj_config='',
32
+ update_gptj_config='',
33
+ load_checkpoint='',
34
+ load_dataset_state='',
35
+ log_freq=50,
36
+ save_model_freq=0,
37
+ save_milestone_freq=0,
38
+ eval_steps=0,
39
+ tokenizer=GPTJConfig.get_tokenizer_config(),
40
+ train_dataset=DatasetFactory.get_default_config(),
41
+ eval_dataset=DatasetFactory.get_default_config(),
42
+ optimizer=OptimizerFactory.get_default_config(),
43
+ checkpointer=StreamingCheckpointer.get_default_config(),
44
+ gptj=GPTJConfig.get_default_config(),
45
+ logger=mlxu.WandBLogger.get_default_config(),
46
+ log_all_worker=False,
47
+ jax_distributed=JaxDistributedConfig.get_default_config(),
48
+ )
49
+
50
+
51
+ def main(argv):
52
+ JaxDistributedConfig.initialize(FLAGS.jax_distributed)
53
+ variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
54
+ flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
55
+ logger = mlxu.WandBLogger(
56
+ config=FLAGS.logger,
57
+ variant=variant,
58
+ enable=FLAGS.log_all_worker or (jax.process_index() == 0),
59
+ )
60
+ set_random_seed(FLAGS.seed)
61
+
62
+ tokenizer = GPTJConfig.get_tokenizer(FLAGS.tokenizer)
63
+ dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer)
64
+ if FLAGS.load_dataset_state != '':
65
+ dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state))
66
+
67
+ if FLAGS.eval_steps > 0:
68
+ eval_dataset = DatasetFactory.load_dataset(
69
+ FLAGS.eval_dataset, dataset.tokenizer
70
+ )
71
+ eval_iterator = iter(eval_dataset)
72
+
73
+ seq_length = dataset.seq_length
74
+
75
+ if FLAGS.load_gptj_config != '':
76
+ gptj_config = GPTJConfig.load_config(FLAGS.load_gptj_config)
77
+ else:
78
+ gptj_config = GPTJConfig(**FLAGS.gptj)
79
+
80
+ if FLAGS.update_gptj_config != '':
81
+ gptj_config.update(dict(eval(FLAGS.update_gptj_config)))
82
+
83
+ gptj_config.update(dict(
84
+ bos_token_id=dataset.tokenizer.bos_token_id,
85
+ eos_token_id=dataset.tokenizer.eos_token_id,
86
+ ))
87
+ if gptj_config.vocab_size < dataset.vocab_size:
88
+ gptj_config.update(dict(vocab_size=dataset.vocab_size))
89
+
90
+ model = FlaxGPTJForCausalLMModule(
91
+ gptj_config, dtype=get_float_dtype_by_name(FLAGS.dtype)
92
+ )
93
+
94
+ optimizer, optimizer_info = OptimizerFactory.get_optimizer(
95
+ FLAGS.optimizer,
96
+ get_weight_decay_mask(GPTJConfig.get_weight_decay_exclusions()),
97
+ )
98
+
99
+ def create_trainstate_from_params(params):
100
+ return TrainState.create(params=params, tx=optimizer, apply_fn=None)
101
+
102
+ def init_fn(rng):
103
+ rng_generator = JaxRNG(rng)
104
+ params = model.init(
105
+ input_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
106
+ position_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
107
+ attention_mask=jnp.ones((4, seq_length), dtype=jnp.int32),
108
+ rngs=rng_generator(gptj_config.rng_keys()),
109
+ )
110
+ return TrainState.create(params=params, tx=optimizer, apply_fn=None)
111
+
112
+ def train_step(train_state, rng, batch):
113
+ rng_generator = JaxRNG(rng)
114
+ batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
115
+ def loss_and_accuracy(params):
116
+ logits = model.apply(
117
+ params, batch['input_tokens'], deterministic=False,
118
+ rngs=rng_generator(gptj_config.rng_keys()),
119
+ ).logits
120
+ return cross_entropy_loss_and_accuracy(
121
+ logits, batch['target_tokens'], batch['loss_masks']
122
+ )
123
+ grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
124
+ (loss, accuracy), grads = grad_fn(train_state.params)
125
+ train_state = train_state.apply_gradients(grads=grads)
126
+ metrics = dict(
127
+ loss=loss,
128
+ accuracy=accuracy,
129
+ learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
130
+ gradient_norm=global_norm(grads),
131
+ param_norm=global_norm(train_state.params),
132
+ )
133
+ return train_state, rng_generator(), metrics
134
+
135
+ def eval_step(train_state, rng, batch):
136
+ rng_generator = JaxRNG(rng)
137
+ batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
138
+ logits = model.apply(
139
+ train_state.params, batch['input_tokens'], deterministic=True,
140
+ rngs=rng_generator(gptj_config.rng_keys()),
141
+ ).logits
142
+ loss, accuracy = cross_entropy_loss_and_accuracy(
143
+ logits, batch['target_tokens'], batch['loss_masks']
144
+ )
145
+ metrics = dict(
146
+ eval_loss=loss,
147
+ eval_accuracy=accuracy,
148
+ )
149
+ return rng_generator(), metrics
150
+
151
+ train_state_shapes = jax.eval_shape(init_fn, next_rng())
152
+ train_state_partition = match_partition_rules(
153
+ GPTJConfig.get_partition_rules(), train_state_shapes
154
+ )
155
+
156
+ shard_fns, gather_fns = make_shard_and_gather_fns(
157
+ train_state_partition, train_state_shapes
158
+ )
159
+ checkpointer = StreamingCheckpointer(
160
+ FLAGS.checkpointer, logger.output_dir,
161
+ enable=jax.process_index() == 0,
162
+ )
163
+
164
+ sharded_init_fn = pjit(
165
+ init_fn,
166
+ in_shardings=PS(),
167
+ out_shardings=train_state_partition
168
+ )
169
+
170
+ sharded_create_trainstate_from_params = pjit(
171
+ create_trainstate_from_params,
172
+ in_shardings=(train_state_partition.params, ),
173
+ out_shardings=train_state_partition,
174
+ donate_argnums=(0, ),
175
+ )
176
+
177
+ sharded_train_step = pjit(
178
+ train_step,
179
+ in_shardings=(train_state_partition, PS(), PS()),
180
+ out_shardings=(train_state_partition, PS(), PS()),
181
+ donate_argnums=(0, 1),
182
+ )
183
+
184
+ sharded_eval_step = pjit(
185
+ eval_step,
186
+ in_shardings=(train_state_partition, PS(), PS()),
187
+ out_shardings=(PS(), PS()),
188
+ donate_argnums=(1,),
189
+ )
190
+
191
+ def save_checkpoint(train_state, milestone=False):
192
+ step = int(jax.device_get(train_state.step))
193
+ metadata = dict(
194
+ step=step,
195
+ variant=variant,
196
+ flags=flags_config_dict,
197
+ gptj_config=gptj_config.to_dict(),
198
+ )
199
+ checkpointer.save_all(
200
+ train_state=train_state,
201
+ gather_fns=gather_fns,
202
+ metadata=metadata,
203
+ dataset=dataset.get_state_dict(),
204
+ milestone=milestone,
205
+ )
206
+
207
+ mesh = GPTJConfig.get_jax_mesh(FLAGS.mesh_dim)
208
+ with mesh:
209
+ train_state, restored_params = None, None
210
+ if FLAGS.load_checkpoint != '':
211
+ load_type, load_path = FLAGS.load_checkpoint.split('::', 1)
212
+ if load_type == 'huggingface':
213
+ restored_params = tree_apply(
214
+ shard_fns.params, gptj_config.load_pretrained(load_path)
215
+ )
216
+ train_state = None
217
+ else:
218
+ train_state, restored_params = checkpointer.load_trainstate_checkpoint(
219
+ FLAGS.load_checkpoint, train_state_shapes, shard_fns
220
+ )
221
+
222
+ if train_state is None and restored_params is None:
223
+ # Initialize from scratch
224
+ train_state = sharded_init_fn(next_rng())
225
+ elif train_state is None and restored_params is not None:
226
+ # Restore from params but initialize train_state
227
+ train_state = sharded_create_trainstate_from_params(restored_params)
228
+ del restored_params
229
+
230
+ start_step = int(jax.device_get(train_state.step))
231
+
232
+ if FLAGS.save_model_freq > 0:
233
+ save_checkpoint(train_state)
234
+
235
+ sharded_rng = next_rng()
236
+
237
+ step_counter = trange(start_step, FLAGS.total_steps, ncols=0)
238
+
239
+ for step, (batch, dataset_metrics) in zip(step_counter, dataset):
240
+ train_state, sharded_rng, metrics = sharded_train_step(
241
+ train_state, sharded_rng, batch
242
+ )
243
+
244
+ if step % FLAGS.log_freq == 0:
245
+ if FLAGS.eval_steps > 0:
246
+ eval_metric_list = []
247
+ for _ in range(FLAGS.eval_steps):
248
+ eval_batch, _ = next(eval_iterator)
249
+ sharded_rng, eval_metrics = sharded_eval_step(
250
+ train_state, sharded_rng, eval_batch
251
+ )
252
+ eval_metric_list.append(eval_metrics)
253
+ metrics.update(average_metrics(eval_metric_list))
254
+
255
+ log_metrics = {"step": step}
256
+ log_metrics.update(metrics)
257
+ log_metrics.update(dataset_metrics)
258
+ log_metrics = jax.device_get(log_metrics)
259
+ logger.log(log_metrics)
260
+ tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")
261
+
262
+ if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0:
263
+ save_checkpoint(train_state, milestone=True)
264
+ elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0:
265
+ save_checkpoint(train_state)
266
+
267
+ if FLAGS.save_model_freq > 0:
268
+ save_checkpoint(train_state)
269
+
270
+
271
+ if __name__ == "__main__":
272
+ mlxu.run(main)
EasyLM/models/llama/convert_easylm_to_hf.py ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
2
+ # Copyright 2023 Xinyang Geng
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # This script converts LLaMA model checkpoint trained by EsayLM to the
17
+ # HuggingFace transformers LLaMA PyTorch format, which can then be loaded
18
+ # by HuggingFace transformers.
19
+
20
+ import gc
21
+ import json
22
+ import math
23
+ import os
24
+ import shutil
25
+
26
+ import numpy as np
27
+ import mlxu
28
+ import jax
29
+ import jax.numpy as jnp
30
+ import flax
31
+ from flax.traverse_util import flatten_dict
32
+ import torch
33
+ from transformers import LlamaConfig, LlamaForCausalLM
34
+
35
+ from EasyLM.checkpoint import StreamingCheckpointer
36
+ from EasyLM.jax_utils import float_tensor_to_dtype
37
+
38
+
39
+ FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
40
+ load_checkpoint='',
41
+ tokenizer_path='',
42
+ model_size='13b',
43
+ output_dir='',
44
+ )
45
+
46
+
47
+ LLAMA_STANDARD_CONFIGS = {
48
+ 'small': {
49
+ 'vocab_size': 64256,
50
+ 'dim': 768,
51
+ 'intermediate_size': 3072,
52
+ 'n_layers': 12,
53
+ 'n_heads': 12,
54
+ 'norm_eps': 1e-6,
55
+ },
56
+ 'medium': {
57
+ 'vocab_size': 64256,
58
+ 'dim': 1024,
59
+ 'intermediate_size': 4096,
60
+ 'n_layers': 24,
61
+ 'n_heads': 16,
62
+ 'norm_eps': 1e-6,
63
+ },
64
+ 'large': {
65
+ 'vocab_size': 64256,
66
+ 'dim': 1536,
67
+ 'intermediate_size': 6144,
68
+ 'n_layers': 24,
69
+ 'n_heads': 16,
70
+ 'norm_eps': 1e-6,
71
+ },
72
+ 'xlarge': {
73
+ 'vocab_size': 64256,
74
+ 'dim': 2048,
75
+ 'intermediate_size': 8192,
76
+ 'n_layers': 24,
77
+ 'n_heads': 32,
78
+ 'norm_eps': 1e-6,
79
+ },
80
+ '1b': {
81
+ 'vocab_size': 64256,
82
+ 'dim': 2048,
83
+ 'intermediate_size': 5504,
84
+ 'n_layers': 22,
85
+ 'n_heads': 16,
86
+ 'norm_eps': 1e-6,
87
+ },
88
+ '3b': {
89
+ 'vocab_size': 64256,
90
+ 'dim': 3200,
91
+ 'intermediate_size': 8640,
92
+ 'n_layers': 26,
93
+ 'n_heads': 32,
94
+ 'norm_eps': 1e-6,
95
+ },
96
+ '7b': {
97
+ 'vocab_size': 64256,
98
+ 'dim': 4096,
99
+ 'intermediate_size': 11008,
100
+ 'n_layers': 32,
101
+ 'n_heads': 32,
102
+ 'norm_eps': 1e-6,
103
+ },
104
+ '13b': {
105
+ 'vocab_size': 64256,
106
+ 'dim': 5120,
107
+ 'intermediate_size': 13824,
108
+ 'n_layers': 40,
109
+ 'n_heads': 40,
110
+ 'norm_eps': 1e-6,
111
+ },
112
+ '30b': {
113
+ 'vocab_size': 64256,
114
+ 'dim': 6656,
115
+ 'intermediate_size': 17920,
116
+ 'n_layers': 60,
117
+ 'n_heads': 52,
118
+ 'norm_eps': 1e-6,
119
+ },
120
+ '65b': {
121
+ 'vocab_size': 64256,
122
+ 'dim': 8192,
123
+ 'intermediate_size': 22016,
124
+ 'n_layers': 80,
125
+ 'n_heads': 64,
126
+ 'norm_eps': 1e-5,
127
+ },
128
+ }
129
+
130
+
131
+ def match_keywords(string, positives, negatives):
132
+ for positive in positives:
133
+ if positive not in string:
134
+ return False
135
+ for negative in negatives:
136
+ if negative in string:
137
+ return False
138
+ return True
139
+
140
+
141
+ def load_and_convert_checkpoint(path):
142
+ _, flax_params = StreamingCheckpointer.load_trainstate_checkpoint(path)
143
+ flax_params = flatten_dict(flax_params['params'], sep='.')
144
+ torch_params = {}
145
+ for key, tensor in flax_params.items():
146
+ if match_keywords(key, ["kernel"], ["norm", 'ln_f']):
147
+ tensor = tensor.T
148
+ torch_params[key] = torch.tensor(
149
+ float_tensor_to_dtype(tensor, 'fp32'), dtype=torch.float16
150
+ )
151
+ return torch_params
152
+
153
+
154
+ def read_json(path):
155
+ with open(path, "r") as f:
156
+ return json.load(f)
157
+
158
+
159
+ def write_json(text, path):
160
+ with open(path, "w") as f:
161
+ json.dump(text, f)
162
+
163
+
164
+ def write_model(loaded, model_path, model_size):
165
+ os.makedirs(model_path, exist_ok=True)
166
+ tmp_model_path = os.path.join(model_path, "tmp")
167
+ os.makedirs(tmp_model_path, exist_ok=True)
168
+
169
+ params = LLAMA_STANDARD_CONFIGS[model_size]
170
+
171
+ n_layers = params["n_layers"]
172
+ n_heads = params["n_heads"]
173
+ dim = params["dim"]
174
+ dims_per_head = dim // n_heads
175
+ base = 10000.0
176
+ inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
177
+
178
+ # permute for sliced rotary
179
+ def permute(w):
180
+ return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)
181
+
182
+
183
+ param_count = 0
184
+ index_dict = {"weight_map": {}}
185
+ for layer_i in range(n_layers):
186
+ filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
187
+ state_dict = {
188
+ f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
189
+ loaded[f"transformer.h.{layer_i}.attention.wq.kernel"]
190
+ ),
191
+ f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
192
+ loaded[f"transformer.h.{layer_i}.attention.wk.kernel"]
193
+ ),
194
+ f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"transformer.h.{layer_i}.attention.wv.kernel"],
195
+ f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"transformer.h.{layer_i}.attention.wo.kernel"],
196
+
197
+ f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"transformer.h.{layer_i}.feed_forward.w1.kernel"],
198
+ f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"transformer.h.{layer_i}.feed_forward.w2.kernel"],
199
+ f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"transformer.h.{layer_i}.feed_forward.w3.kernel"],
200
+
201
+ f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"transformer.h.{layer_i}.attention_norm.kernel"],
202
+ f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"transformer.h.{layer_i}.ffn_norm.kernel"],
203
+
204
+ }
205
+
206
+ state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
207
+ for k, v in state_dict.items():
208
+ index_dict["weight_map"][k] = filename
209
+ param_count += v.numel()
210
+ torch.save(state_dict, os.path.join(tmp_model_path, filename))
211
+
212
+ filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
213
+ # Unsharded
214
+ state_dict = {
215
+ "model.embed_tokens.weight": loaded["transformer.wte.embedding"],
216
+ "model.norm.weight": loaded["transformer.ln_f.kernel"],
217
+ "lm_head.weight": loaded["lm_head.kernel"],
218
+ }
219
+
220
+ for k, v in state_dict.items():
221
+ index_dict["weight_map"][k] = filename
222
+ param_count += v.numel()
223
+ torch.save(state_dict, os.path.join(tmp_model_path, filename))
224
+
225
+ # Write configs
226
+ index_dict["metadata"] = {"total_size": param_count * 2}
227
+ write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
228
+
229
+ config = LlamaConfig(
230
+ vocab_size=params["vocab_size"],
231
+ hidden_size=dim,
232
+ intermediate_size=params["intermediate_size"],
233
+ num_attention_heads=params["n_heads"],
234
+ num_hidden_layers=params["n_layers"],
235
+ rms_norm_eps=params["norm_eps"],
236
+ )
237
+ config.save_pretrained(tmp_model_path)
238
+
239
+ # Make space so we can load the model properly now.
240
+ del state_dict
241
+ del loaded
242
+ gc.collect()
243
+
244
+ print("Loading the checkpoint in a Llama model.")
245
+ model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float16)
246
+ # Avoid saving this as part of the config.
247
+ print("Model parameter count", model.num_parameters())
248
+ del model.config._name_or_path
249
+
250
+ print("Saving in the Transformers format.")
251
+ model.save_pretrained(model_path, safe_serialization=True)
252
+ shutil.rmtree(tmp_model_path)
253
+
254
+
255
+ def write_tokenizer(tokenizer_path, input_tokenizer_path):
256
+ print(f"Fetching the tokenizer from {input_tokenizer_path}.")
257
+ os.makedirs(tokenizer_path, exist_ok=True)
258
+ write_json(
259
+ {
260
+ "bos_token": {
261
+ "content": "<s>",
262
+ "lstrip": False,
263
+ "normalized": True,
264
+ "rstrip": False,
265
+ "single_word": False
266
+ },
267
+ "eos_token": {
268
+ "content": "</s>",
269
+ "lstrip": False,
270
+ "normalized": True,
271
+ "rstrip": False,
272
+ "single_word": False
273
+ },
274
+ "unk_token": {
275
+ "content": "<unk>",
276
+ "lstrip": False,
277
+ "normalized": True,
278
+ "rstrip": False,
279
+ "single_word": False
280
+ },
281
+ },
282
+ os.path.join(tokenizer_path, "special_tokens_map.json")
283
+ )
284
+ write_json(
285
+ {
286
+ "add_bos_token": True,
287
+ "add_eos_token": False,
288
+ "model_max_length": 2048,
289
+ "pad_token": None,
290
+ "sp_model_kwargs": {},
291
+ "tokenizer_class": "LlamaTokenizer",
292
+ "clean_up_tokenization_spaces": False,
293
+ "bos_token": {
294
+ "__type": "AddedToken",
295
+ "content": "<s>",
296
+ "lstrip": False,
297
+ "normalized": True,
298
+ "rstrip": False,
299
+ "single_word": False
300
+ },
301
+ "eos_token": {
302
+ "__type": "AddedToken",
303
+ "content": "</s>",
304
+ "lstrip": False,
305
+ "normalized": True,
306
+ "rstrip": False,
307
+ "single_word": False
308
+ },
309
+ "unk_token": {
310
+ "__type": "AddedToken",
311
+ "content": "<unk>",
312
+ "lstrip": False,
313
+ "normalized": True,
314
+ "rstrip": False,
315
+ "single_word": False
316
+ },
317
+ },
318
+ os.path.join(tokenizer_path, "tokenizer_config.json"),
319
+ )
320
+ shutil.copyfile(input_tokenizer_path, os.path.join(tokenizer_path, "tokenizer.model"))
321
+
322
+
323
+ def main(argv):
324
+ assert FLAGS.load_checkpoint != "" and FLAGS.output_dir != ""# and FLAGS.tokenizer_path != ""
325
+ assert FLAGS.model_size in LLAMA_STANDARD_CONFIGS
326
+ # write_tokenizer(
327
+ # tokenizer_path=FLAGS.output_dir,
328
+ # input_tokenizer_path=FLAGS.tokenizer_path,
329
+ # )
330
+ write_model(
331
+ load_and_convert_checkpoint(FLAGS.load_checkpoint),
332
+ model_path=FLAGS.output_dir,
333
+ model_size=FLAGS.model_size,
334
+ )
335
+
336
+
337
+ if __name__ == "__main__":
338
+ mlxu.run(main)
EasyLM/models/llama/convert_hf_to_easylm.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Usage:
3
+ python convert_hf_to_easylm.py \
4
+ --checkpoint_dir /path/hf_format_dir/ \
5
+ --output_file /path/easylm_format.stream \
6
+ --model_size 7b \
7
+ --streaming
8
+ """
9
+ import time
10
+ from pathlib import Path
11
+ import argparse
12
+
13
+ import mlxu
14
+ import torch
15
+ import flax
16
+
17
+ from EasyLM.checkpoint import StreamingCheckpointer
18
+
19
+ LLAMA_STANDARD_CONFIGS = {
20
+ '1b': {
21
+ 'dim': 2048,
22
+ 'intermediate_size': 5504,
23
+ 'n_layers': 22,
24
+ 'n_heads': 16,
25
+ 'norm_eps': 1e-6,
26
+ },
27
+ '3b': {
28
+ 'dim': 3200,
29
+ 'intermediate_size': 8640,
30
+ 'n_layers': 26,
31
+ 'n_heads': 32,
32
+ 'norm_eps': 1e-6,
33
+ },
34
+ "7b": {
35
+ "dim": 4096,
36
+ "intermediate_size": 11008,
37
+ "n_layers": 32,
38
+ "n_heads": 32,
39
+ "norm_eps": 1e-6,
40
+ },
41
+ "13b": {
42
+ "dim": 5120,
43
+ "intermediate_size": 13824,
44
+ "n_layers": 40,
45
+ "n_heads": 40,
46
+ "norm_eps": 1e-6,
47
+ },
48
+ "30b": {
49
+ "dim": 6656,
50
+ "intermediate_size": 17920,
51
+ "n_layers": 60,
52
+ "n_heads": 52,
53
+ "norm_eps": 1e-6,
54
+ },
55
+ "65b": {
56
+ "dim": 8192,
57
+ "intermediate_size": 22016,
58
+ "n_layers": 80,
59
+ "n_heads": 64,
60
+ "norm_eps": 1e-5,
61
+ },
62
+ }
63
+
64
+
65
+ def inverse_permute(params, w):
66
+ n_layers = params["n_layers"]
67
+ n_heads = params["n_heads"]
68
+ dim = params["dim"]
69
+ reshaped_w = w.reshape(n_heads, 2, dim // n_heads // 2, dim)
70
+ transposed_w = reshaped_w.transpose(0, 2, 1, 3)
71
+ inverted_w = transposed_w.reshape(dim, dim)
72
+ return inverted_w
73
+
74
+
75
+ def main(args):
76
+ start = time.time()
77
+ params = LLAMA_STANDARD_CONFIGS[args.model_size]
78
+
79
+ ckpt_paths = sorted(Path(args.checkpoint_dir).glob("*.bin"))
80
+ ckpt = {}
81
+ for i, ckpt_path in enumerate(ckpt_paths):
82
+ checkpoint = torch.load(ckpt_path, map_location="cpu")
83
+ for k, v in checkpoint.items():
84
+ if k.startswith("model."):
85
+ k = k[6:]
86
+ ckpt[k] = v
87
+ print(f"Start convert weight to easylm format...")
88
+ jax_weights = {
89
+ "transformer": {
90
+ "wte": {"embedding": ckpt["embed_tokens.weight"].numpy()},
91
+ "ln_f": {"kernel": ckpt["norm.weight"].numpy()},
92
+ "h": {
93
+ "%d"
94
+ % (layer): {
95
+ "attention": {
96
+ "wq": {
97
+ "kernel": inverse_permute(
98
+ params,
99
+ ckpt[f"layers.{layer}.self_attn.q_proj.weight"].numpy(),
100
+ ).transpose()
101
+ },
102
+ "wk": {
103
+ "kernel": inverse_permute(
104
+ params,
105
+ ckpt[f"layers.{layer}.self_attn.k_proj.weight"].numpy(),
106
+ ).transpose()
107
+ },
108
+ "wv": {
109
+ "kernel": ckpt[f"layers.{layer}.self_attn.v_proj.weight"]
110
+ .numpy()
111
+ .transpose()
112
+ },
113
+ "wo": {
114
+ "kernel": ckpt[f"layers.{layer}.self_attn.o_proj.weight"]
115
+ .numpy()
116
+ .transpose()
117
+ },
118
+ },
119
+ "feed_forward": {
120
+ "w1": {
121
+ "kernel": ckpt[f"layers.{layer}.mlp.gate_proj.weight"]
122
+ .numpy()
123
+ .transpose()
124
+ },
125
+ "w2": {
126
+ "kernel": ckpt[f"layers.{layer}.mlp.down_proj.weight"]
127
+ .numpy()
128
+ .transpose()
129
+ },
130
+ "w3": {
131
+ "kernel": ckpt[f"layers.{layer}.mlp.up_proj.weight"]
132
+ .numpy()
133
+ .transpose()
134
+ },
135
+ },
136
+ "attention_norm": {
137
+ "kernel": ckpt[f"layers.{layer}.input_layernorm.weight"].numpy()
138
+ },
139
+ "ffn_norm": {
140
+ "kernel": ckpt[
141
+ f"layers.{layer}.post_attention_layernorm.weight"
142
+ ].numpy()
143
+ },
144
+ }
145
+ for layer in range(params["n_layers"])
146
+ },
147
+ },
148
+ "lm_head": {"kernel": ckpt["lm_head.weight"].numpy().transpose()},
149
+ }
150
+ print(f"Convert weight to easylm format finished...")
151
+ print(f"Start to save...")
152
+
153
+ if args.streaming:
154
+ StreamingCheckpointer.save_train_state_to_file(jax_weights, args.output_file)
155
+ else:
156
+ with mlxu.open_file(args.output_file, "wb") as fout:
157
+ fout.write(flax.serialization.msgpack_serialize(jax_weights, in_place=True))
158
+
159
+ print(
160
+ f"Save finished!!! take time: {time.time() - start} save path: {args.output_file}"
161
+ )
162
+
163
+
164
+ if __name__ == "__main__":
165
+ parser = argparse.ArgumentParser(description="hf to easylm format script")
166
+
167
+ parser.add_argument(
168
+ "--checkpoint_dir",
169
+ type=str,
170
+ help="Need to be converted model weight dir. it is a dir",
171
+ )
172
+ parser.add_argument(
173
+ "--output_file", type=str, help="Save model weight file path, it is a file."
174
+ )
175
+ parser.add_argument(
176
+ "--model_size",
177
+ type=str,
178
+ default="7b",
179
+ choices=["7b", "13b", "30b", "65b"],
180
+ help="model size",
181
+ )
182
+ parser.add_argument(
183
+ "--streaming",
184
+ action="store_true",
185
+ default=True,
186
+ help="whether is model weight saved stream format",
187
+ )
188
+
189
+ args = parser.parse_args()
190
+
191
+ print(f"checkpoint_dir: {args.checkpoint_dir}")
192
+ print(f"output_file: {args.output_file}")
193
+ print(f"model_size: {args.model_size}")
194
+ print(f"streaming: {args.streaming}")
195
+
196
+ main(args)
EasyLM/models/llama/convert_torch_to_easylm.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This script converts the standrd LLaMA PyTorch checkpoint released by Meta
2
+ # to the EasyLM checkpoint format. The converted checkpoint can then be loaded
3
+ # by EasyLM for fine-tuning or inference.
4
+
5
+ # This script is largely borrow from https://github.com/Sea-Snell/JAX_llama
6
+
7
+ from pathlib import Path
8
+ import json
9
+ import numpy as np
10
+ import torch
11
+ import flax
12
+ import mlxu
13
+
14
+ from EasyLM.checkpoint import StreamingCheckpointer
15
+
16
+
17
+ FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
18
+ checkpoint_dir='',
19
+ output_file='',
20
+ streaming=True,
21
+ )
22
+
23
+
24
+ def main(argv):
25
+ ckpt_paths = sorted(Path(FLAGS.checkpoint_dir).glob("*.pth"))
26
+ ckpts = {}
27
+ for i, ckpt_path in enumerate(ckpt_paths):
28
+ checkpoint = torch.load(ckpt_path, map_location="cpu")
29
+ ckpts[int(ckpt_path.name.split('.', maxsplit=2)[1])] = checkpoint
30
+ ckpts = [ckpts[i] for i in sorted(list(ckpts.keys()))]
31
+ with open(Path(FLAGS.checkpoint_dir) / "params.json", "r") as f:
32
+ params = json.loads(f.read())
33
+
34
+ jax_weights = {
35
+ 'transformer': {
36
+ 'wte': {'embedding': np.concatenate([ckpt['tok_embeddings.weight'].numpy() for ckpt in ckpts], axis=1)},
37
+ 'ln_f': {'kernel': ckpts[0]['norm.weight'].numpy()},
38
+ 'h': {
39
+ '%d' % (layer): {
40
+ 'attention': {
41
+ 'wq': {'kernel': np.concatenate([ckpt['layers.%d.attention.wq.weight' % (layer)].numpy() for ckpt in ckpts], axis=0).transpose()},
42
+ 'wk': {'kernel': np.concatenate([ckpt['layers.%d.attention.wk.weight' % (layer)].numpy() for ckpt in ckpts], axis=0).transpose()},
43
+ 'wv': {'kernel': np.concatenate([ckpt['layers.%d.attention.wv.weight' % (layer)].numpy() for ckpt in ckpts], axis=0).transpose()},
44
+ 'wo': {'kernel': np.concatenate([ckpt['layers.%d.attention.wo.weight' % (layer)].numpy() for ckpt in ckpts], axis=1).transpose()},
45
+ },
46
+ 'feed_forward': {
47
+ 'w1': {'kernel': np.concatenate([ckpt['layers.%d.feed_forward.w1.weight' % (layer)].numpy() for ckpt in ckpts], axis=0).transpose()},
48
+ 'w2': {'kernel': np.concatenate([ckpt['layers.%d.feed_forward.w2.weight' % (layer)].numpy() for ckpt in ckpts], axis=1).transpose()},
49
+ 'w3': {'kernel': np.concatenate([ckpt['layers.%d.feed_forward.w3.weight' % (layer)].numpy() for ckpt in ckpts], axis=0).transpose()},
50
+ },
51
+ 'attention_norm': {'kernel': ckpts[0]['layers.%d.attention_norm.weight' % (layer)].numpy()},
52
+ 'ffn_norm': {'kernel': ckpts[0]['layers.%d.ffn_norm.weight' % (layer)].numpy()},
53
+ }
54
+ for layer in range(params['n_layers'])},
55
+ },
56
+ 'lm_head': {'kernel': np.concatenate([ckpt['output.weight'].numpy() for ckpt in ckpts], axis=0).transpose()},
57
+ }
58
+ if FLAGS.streaming:
59
+ StreamingCheckpointer.save_train_state_to_file(
60
+ jax_weights, FLAGS.output_file
61
+ )
62
+ else:
63
+ with mlxu.open_file(FLAGS.output_file, 'wb') as fout:
64
+ fout.write(flax.serialization.msgpack_serialize(jax_weights, in_place=True))
65
+
66
+
67
+ if __name__ == '__main__':
68
+ mlxu.run(main)
EasyLM/models/llama/llama_model.py ADDED
@@ -0,0 +1,1360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from shutil import copyfile
3
+ from typing import Any, Dict, List, Optional, Tuple, Union
4
+ import json
5
+ import tempfile
6
+ from functools import partial
7
+
8
+ import numpy as np
9
+ import jax
10
+ import jax.numpy as jnp
11
+ from jax import lax
12
+ from jax.sharding import PartitionSpec as PS
13
+ import flax.linen as nn
14
+ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
15
+ from flax.linen import combine_masks, make_causal_mask
16
+ from flax.linen.attention import dot_product_attention_weights
17
+ from flax.traverse_util import flatten_dict, unflatten_dict
18
+ from flax.linen import partitioning as nn_partitioning
19
+ import einops
20
+
21
+ import sentencepiece as spm
22
+ from transformers import AutoTokenizer
23
+ from transformers.configuration_utils import PretrainedConfig
24
+ from transformers.utils import logging
25
+ from transformers.tokenization_utils import PreTrainedTokenizer
26
+ from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
27
+ from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
28
+ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
29
+
30
+ from ml_collections import ConfigDict
31
+ from ml_collections.config_dict import config_dict
32
+ from mlxu import function_args_to_config, load_pickle, open_file
33
+
34
+ from EasyLM.bpt import blockwise_ffn, blockwise_attn
35
+ from EasyLM.jax_utils import (
36
+ with_sharding_constraint, get_jax_mesh, get_gradient_checkpoint_policy
37
+ )
38
+
39
+
40
+ LLAMA_STANDARD_CONFIGS = {
41
+ 'small': {
42
+ 'vocab_size': 64256,
43
+ 'hidden_size': 768,
44
+ 'intermediate_size': 3072,
45
+ 'num_hidden_layers': 12,
46
+ 'num_attention_heads': 12,
47
+ 'max_sequence_length': 2048,
48
+ 'initializer_range': 0.02,
49
+ 'rms_norm_eps': 1e-6,
50
+ 'use_cache': True,
51
+ 'tie_word_embeddings': False,
52
+ },
53
+ 'medium': {
54
+ 'vocab_size': 64256,
55
+ 'hidden_size': 1024,
56
+ 'intermediate_size': 4096,
57
+ 'num_hidden_layers': 24,
58
+ 'num_attention_heads': 16,
59
+ 'max_sequence_length': 2048,
60
+ 'initializer_range': 0.02,
61
+ 'rms_norm_eps': 1e-6,
62
+ 'use_cache': True,
63
+ 'tie_word_embeddings': False,
64
+ },
65
+ 'large': {
66
+ 'vocab_size': 64256,
67
+ 'hidden_size': 1536,
68
+ 'intermediate_size': 6144,
69
+ 'num_hidden_layers': 24,
70
+ 'num_attention_heads': 16,
71
+ 'max_sequence_length': 2048,
72
+ 'initializer_range': 0.02,
73
+ 'rms_norm_eps': 1e-6,
74
+ 'use_cache': True,
75
+ 'tie_word_embeddings': False,
76
+ },
77
+ 'xlarge': {
78
+ 'vocab_size': 64256,
79
+ 'hidden_size': 2048,
80
+ 'intermediate_size': 8192,
81
+ 'num_hidden_layers': 24,
82
+ 'num_attention_heads': 32,
83
+ 'max_sequence_length': 2048,
84
+ 'initializer_range': 0.02,
85
+ 'rms_norm_eps': 1e-6,
86
+ 'use_cache': True,
87
+ 'tie_word_embeddings': False,
88
+ },
89
+ '1b': {
90
+ 'vocab_size': 64256,
91
+ 'hidden_size': 2048,
92
+ 'intermediate_size': 5504,
93
+ 'num_hidden_layers': 22,
94
+ 'num_attention_heads': 16,
95
+ 'max_sequence_length': 2048,
96
+ 'initializer_range': 0.02,
97
+ 'rms_norm_eps': 1e-6,
98
+ 'use_cache': True,
99
+ 'tie_word_embeddings': False,
100
+ },
101
+ '3b': {
102
+ 'vocab_size': 64256,
103
+ 'hidden_size': 3200,
104
+ 'intermediate_size': 8640,
105
+ 'num_hidden_layers': 26,
106
+ 'num_attention_heads': 32,
107
+ 'max_sequence_length': 2048,
108
+ 'initializer_range': 0.02,
109
+ 'rms_norm_eps': 1e-6,
110
+ 'use_cache': True,
111
+ 'tie_word_embeddings': False,
112
+ },
113
+ '7b': {
114
+ 'vocab_size': 64256,
115
+ 'hidden_size': 4096,
116
+ 'intermediate_size': 11008,
117
+ 'num_hidden_layers': 32,
118
+ 'num_attention_heads': 32,
119
+ 'max_sequence_length': 2048,
120
+ 'initializer_range': 0.02,
121
+ 'rms_norm_eps': 1e-6,
122
+ 'use_cache': True,
123
+ 'tie_word_embeddings': False,
124
+ },
125
+ '13b': {
126
+ 'vocab_size': 64256,
127
+ 'hidden_size': 5120,
128
+ 'intermediate_size': 13824,
129
+ 'num_hidden_layers': 40,
130
+ 'num_attention_heads': 40,
131
+ 'max_sequence_length': 2048,
132
+ 'initializer_range': 0.02,
133
+ 'rms_norm_eps': 1e-6,
134
+ 'use_cache': True,
135
+ 'tie_word_embeddings': False,
136
+ },
137
+ '30b': {
138
+ 'vocab_size': 64256,
139
+ 'hidden_size': 6656,
140
+ 'intermediate_size': 17920,
141
+ 'num_hidden_layers': 60,
142
+ 'num_attention_heads': 52,
143
+ 'max_sequence_length': 2048,
144
+ 'initializer_range': 0.02,
145
+ 'rms_norm_eps': 1e-6,
146
+ 'use_cache': True,
147
+ 'tie_word_embeddings': False,
148
+ },
149
+ '65b': {
150
+ 'vocab_size': 64256,
151
+ 'hidden_size': 8192,
152
+ 'intermediate_size': 22016,
153
+ 'num_hidden_layers': 80,
154
+ 'num_attention_heads': 64,
155
+ 'max_sequence_length': 2048,
156
+ 'initializer_range': 0.02,
157
+ 'rms_norm_eps': 1e-5,
158
+ 'use_cache': True,
159
+ 'tie_word_embeddings': False,
160
+ },
161
+ 'debug': { # A small model for debugging
162
+ 'vocab_size': 64256,
163
+ 'hidden_size': 128,
164
+ 'intermediate_size': 256,
165
+ 'num_hidden_layers': 2,
166
+ 'num_attention_heads': 4,
167
+ 'max_sequence_length': 2048,
168
+ 'initializer_range': 0.02,
169
+ 'rms_norm_eps': 1e-6,
170
+ 'use_cache': True,
171
+ 'tie_word_embeddings': False,
172
+ },
173
+ }
174
+
175
+
176
+ class LLaMAConfig(PretrainedConfig):
177
+ r"""
178
+ This is the configuration class to store the configuration of a [`~LLaMAModel`]. It is used to instantiate an LLaMA
179
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
180
+ defaults will yield a similar configuration to that of the LLaMA-7B.
181
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
182
+ documentation from [`PretrainedConfig`] for more information.
183
+ Args:
184
+ vocab_size (`int`, *optional*, defaults to 32000):
185
+ Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
186
+ `inputs_ids` passed when calling [`~LLaMAModel`] or [`~TFLLaMAModel`].
187
+ hidden_size (`int`, *optional*, defaults to 4096):
188
+ Dimension of the hidden representations.
189
+ intermediate_size (`int`, *optional*, defaults to 11008):
190
+ Dimension of the MLP representations.
191
+ num_hidden_layers (`int`, *optional*, defaults to 32):
192
+ Number of hidden layers in the Transformer encoder.
193
+ num_attention_heads (`int`, *optional*, defaults to 32):
194
+ Number of attention heads for each attention layer in the Transformer encoder.
195
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
196
+ The non-linear activation function (function or string) in the decoder.
197
+ max_sequence_length (`int`, *optional*, defaults to 2048):
198
+ Max sequence length for model (for RoPE computation)
199
+ initializer_range (`float`, *optional*, defaults to 0.02):
200
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
201
+ rms_norm_eps (`float`, *optional*, defaults to 1e-12):
202
+ The epsilon used by the rms normalization layers.
203
+ use_cache (`bool`, *optional*, defaults to `True`):
204
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
205
+ relevant if `config.is_decoder=True`.
206
+ tie_word_embeddings(`bool`, *optional*, defaults to `False`):
207
+ Whether to tie weight embeddings
208
+ Example:
209
+ ```python
210
+ >>> from transformers import LLaMAModel, LLaMAConfig
211
+ >>> # Initializing a LLaMA llama-7b style configuration
212
+ >>> configuration = LLaMAConfig()
213
+ >>> # Initializing a model from the llama-7b style configuration
214
+ >>> model = LLaMAModel(configuration)
215
+ >>> # Accessing the model configuration
216
+ >>> configuration = model.config
217
+ ```"""
218
+ model_type = "llama"
219
+
220
+ def __init__(
221
+ self,
222
+ vocab_size=32000,
223
+ hidden_size=4096,
224
+ intermediate_size=11008,
225
+ num_hidden_layers=32,
226
+ num_attention_heads=32,
227
+ max_sequence_length=2048,
228
+ rms_norm_eps=1e-6,
229
+ initializer_range=0.02,
230
+ use_cache=True,
231
+ # pad_token_id=-1,
232
+ bos_token_id=0,
233
+ eos_token_id=1,
234
+ resid_pdrop=0.0,
235
+ embd_pdrop=0.0,
236
+ attn_pdrop=0.0,
237
+ tie_word_embeddings=False,
238
+ remat_block='nothing_saveable',
239
+ remat_attention='',
240
+ remat_mlp='',
241
+ scan_attention=False,
242
+ scan_mlp=False,
243
+ scan_query_chunk_size=1024,
244
+ scan_key_chunk_size=1024,
245
+ scan_mlp_chunk_size=1024,
246
+ fcm_min_ratio=0.0,
247
+ fcm_max_ratio=0.0,
248
+ **kwargs,
249
+ ):
250
+ self.vocab_size = vocab_size
251
+ self.hidden_size = hidden_size
252
+ self.initializer_range = initializer_range
253
+ self.intermediate_size = intermediate_size
254
+ self.num_hidden_layers = num_hidden_layers
255
+ self.num_attention_heads = num_attention_heads
256
+ self.max_sequence_length = max_sequence_length
257
+ self.rms_norm_eps = rms_norm_eps
258
+ self.use_cache = use_cache
259
+ self.resid_pdrop = resid_pdrop
260
+ self.embd_pdrop = embd_pdrop
261
+ self.attn_pdrop = attn_pdrop
262
+ self.remat_block = remat_block
263
+ self.remat_attention = remat_attention
264
+ self.remat_mlp = remat_mlp
265
+ self.scan_attention = scan_attention
266
+ self.scan_mlp = scan_mlp
267
+ self.scan_query_chunk_size = scan_query_chunk_size
268
+ self.scan_key_chunk_size = scan_key_chunk_size
269
+ self.scan_mlp_chunk_size = scan_mlp_chunk_size
270
+ self.fcm_min_ratio = fcm_min_ratio
271
+ self.fcm_max_ratio = fcm_max_ratio
272
+ super().__init__(
273
+ # pad_token_id=pad_token_id,
274
+ bos_token_id=bos_token_id,
275
+ eos_token_id=eos_token_id,
276
+ tie_word_embeddings=tie_word_embeddings,
277
+ **kwargs,
278
+ )
279
+
280
+ @classmethod
281
+ def get_default_config(cls, updates=None):
282
+ config = function_args_to_config(cls.__init__)
283
+
284
+ if updates is not None:
285
+ config.update(ConfigDict(updates).copy_and_resolve_references())
286
+
287
+ return config
288
+
289
+ @staticmethod
290
+ def get_jax_mesh(axis_dims):
291
+ return get_jax_mesh(axis_dims, ('dp', 'fsdp', 'mp'))
292
+
293
+ @staticmethod
294
+ def get_partition_rules():
295
+ """ Parition rules for GPTJ. Note that these rules are orderd, so that
296
+ the beginning rules match first. It is important to use
297
+ PartitionSpec() instead of None here because JAX does not treat
298
+ None as a pytree leaf.
299
+ """
300
+ return (
301
+ # embeddings
302
+ ("transformer/wte/embedding", PS("mp", "fsdp")),
303
+ # atention
304
+ ("attention/(wq|wk|wv)/kernel", PS("fsdp", "mp")),
305
+ ("attention/wo/kernel", PS("mp", "fsdp")),
306
+ # mlp
307
+ ("feed_forward/w1/kernel", PS("fsdp", "mp")),
308
+ ("feed_forward/w2/kernel", PS("mp", "fsdp")),
309
+ ("feed_forward/w3/kernel", PS("fsdp", "mp")),
310
+ # layer norms
311
+ ("attention_norm/kernel", PS(None)),
312
+ ("ffn_norm/kernel", PS(None)),
313
+ # output head
314
+ ("transformer/ln_f/kernel", PS(None)),
315
+ ("lm_head/kernel", PS("fsdp", "mp")),
316
+ ('.*', PS(None)),
317
+ )
318
+
319
+ @staticmethod
320
+ def get_weight_decay_exclusions():
321
+ return (
322
+ "attention_norm/kernel",
323
+ "ffn_norm/kernel",
324
+ "transformer/ln_f/kernel",
325
+ )
326
+
327
+ @staticmethod
328
+ def rng_keys():
329
+ return ('params', 'dropout', 'fcm')
330
+
331
+ @staticmethod
332
+ def get_tokenizer_config(updates=None):
333
+ config = ConfigDict()
334
+ config.vocab_file = ''
335
+ config.pretrained_model_name_or_path = ''
336
+ config.add_bos_token = False
337
+ config.add_eos_token = False
338
+
339
+ if updates is not None:
340
+ config.update(ConfigDict(updates).copy_and_resolve_references())
341
+ return config
342
+
343
+ @classmethod
344
+ def get_tokenizer(cls, config, padding_side='left', truncation_side='right'):
345
+ config = cls.get_tokenizer_config(config)
346
+ if config.vocab_file == '':
347
+ assert config.pretrained_model_name_or_path != '', 'vocab_file or pretrained_model_name_or_path must be specified'
348
+
349
+ if config.pretrained_model_name_or_path != '':
350
+ tokenizer = AutoTokenizer.from_pretrained(
351
+ config.pretrained_model_name_or_path,
352
+ add_bos_token=config.add_bos_token,
353
+ add_eos_token=config.add_eos_token,
354
+ padding_side=padding_side,
355
+ truncation_side=truncation_side,
356
+ )
357
+ else:
358
+ tokenizer = LLaMATokenizer(
359
+ vocab_file=config.vocab_file,
360
+ add_bos_token=config.add_bos_token,
361
+ add_eos_token=config.add_eos_token,
362
+ padding_side=padding_side,
363
+ truncation_side=truncation_side,
364
+ )
365
+ return tokenizer
366
+
367
+ @classmethod
368
+ def load_config(cls, path):
369
+ if path in LLAMA_STANDARD_CONFIGS:
370
+ return cls.from_dict(LLAMA_STANDARD_CONFIGS[path])
371
+ load_type, load_path = path.split('::', 1)
372
+ if load_type == 'pickle':
373
+ return cls.from_dict(load_pickle(load_path)['llama_config'])
374
+ elif load_type == 'json':
375
+ with open_file(load_path, 'r') as fin:
376
+ raw_config = fin.read()
377
+ return cls.from_dict(json.loads(raw_config))
378
+ else:
379
+ raise ValueError(f'Unsupported load config type: {load_type}')
380
+
381
+
382
+ remat = nn_partitioning.remat
383
+
384
+ logger = logging.get_logger(__name__)
385
+
386
+
387
+ class RMSNorm(nn.Module):
388
+ dim: int
389
+ eps: float=1e-6
390
+ dtype: jnp.dtype=jnp.float32
391
+ param_dtype: jnp.dtype=jnp.float32
392
+
393
+ def setup(self) -> None:
394
+ self.weight = self.param(
395
+ 'kernel',
396
+ nn.initializers.ones,
397
+ (self.dim,),
398
+ self.param_dtype,
399
+ )
400
+
401
+ def _norm(self, x: jnp.ndarray) -> jnp.ndarray:
402
+ return x * jax.lax.rsqrt(jnp.square(x).mean(-1, keepdims=True) + self.eps)
403
+
404
+ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
405
+ x = x.astype(jnp.promote_types(self.dtype, jnp.float32))
406
+ output = self._norm(x).astype(self.dtype)
407
+ weight = jnp.asarray(self.weight, self.dtype)
408
+ return output * weight
409
+
410
+ def precompute_freqs_cis(dim: int, end: int, theta: float=10000.0, dtype: jnp.dtype=jnp.float32) -> jnp.ndarray:
411
+ freqs = 1.0 / (theta ** (np.arange(0, dim, 2)[: (dim // 2)].astype(dtype) / dim))
412
+ t = np.arange(end) # type: ignore
413
+ freqs = np.outer(t, freqs).astype(dtype) # type: ignore
414
+ sin, cos = np.sin(freqs), np.cos(freqs)
415
+ freqs_cis = np.complex64(cos + 1j * sin)
416
+ return jnp.asarray(freqs_cis)
417
+
418
+ def apply_rotary_emb(
419
+ xq: jnp.ndarray,
420
+ xk: jnp.ndarray,
421
+ freqs_cis: jnp.ndarray,
422
+ dtype: jnp.dtype=jnp.float32,
423
+ ) -> Tuple[jnp.ndarray, jnp.ndarray]:
424
+
425
+ reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2)
426
+ reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2)
427
+
428
+ xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1])
429
+ xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1])
430
+
431
+ # add head dim
432
+ freqs_cis = jnp.reshape(freqs_cis, (*freqs_cis.shape[:2], 1, *freqs_cis.shape[2:]))
433
+
434
+ xq_out = xq_ * freqs_cis
435
+ xq_out = jnp.stack((jnp.real(xq_out), jnp.imag(xq_out)), axis=-1).reshape(*xq_out.shape[:-1], -1)
436
+
437
+ xk_out = xk_ * freqs_cis
438
+ xk_out = jnp.stack((jnp.real(xk_out), jnp.imag(xk_out)), axis=-1).reshape(*xk_out.shape[:-1], -1)
439
+
440
+ return xq_out.astype(dtype), xk_out.astype(dtype)
441
+
442
+
443
+ class FlaxLLaMAAttention(nn.Module):
444
+ config: LLaMAConfig
445
+ dtype: jnp.dtype=jnp.float32
446
+ param_dtype: jnp.dtype=jnp.float32
447
+ precision: Optional[Union[jax.lax.Precision, str]]=None
448
+
449
+ def setup(self):
450
+ config = self.config
451
+ self.embed_dim = config.hidden_size
452
+ self.num_heads = config.num_attention_heads
453
+ self.head_dim = self.embed_dim // self.num_heads
454
+
455
+ self.wq = nn.Dense(
456
+ config.num_attention_heads*self.head_dim,
457
+ dtype=self.dtype,
458
+ param_dtype=self.param_dtype,
459
+ use_bias=False,
460
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
461
+ precision=self.precision,
462
+ )
463
+ self.wk = nn.Dense(
464
+ config.num_attention_heads*self.head_dim,
465
+ dtype=self.dtype,
466
+ param_dtype=self.param_dtype,
467
+ use_bias=False,
468
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
469
+ precision=self.precision,
470
+ )
471
+ self.wv = nn.Dense(
472
+ config.num_attention_heads*self.head_dim,
473
+ dtype=self.dtype,
474
+ param_dtype=self.param_dtype,
475
+ use_bias=False,
476
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
477
+ precision=self.precision,
478
+ )
479
+ self.wo = nn.Dense(
480
+ config.hidden_size,
481
+ dtype=self.dtype,
482
+ param_dtype=self.param_dtype,
483
+ use_bias=False,
484
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
485
+ precision=self.precision,
486
+ )
487
+
488
+ self.resid_dropout = nn.Dropout(rate=config.resid_pdrop)
489
+
490
+ self.causal_mask = make_causal_mask(jnp.ones((1, config.max_sequence_length), dtype="bool"), dtype="bool")
491
+
492
+ self.freqs_cis = precompute_freqs_cis(
493
+ self.head_dim,
494
+ config.max_sequence_length * 2,
495
+ dtype=self.dtype,
496
+ )
497
+
498
+ def _split_heads(self, hidden_states):
499
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
500
+
501
+ def _merge_heads(self, hidden_states):
502
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
503
+
504
+ @nn.compact
505
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
506
+ """
507
+ This function takes projected key, value states from a single input token and concatenates the states to cached
508
+ states from previous steps. This function is slighly adapted from the official Flax repository:
509
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
510
+ """
511
+ # detect if we're initializing by absence of existing cache data.
512
+ is_initialized = self.has_variable("cache", "cached_key")
513
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
514
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
515
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
516
+
517
+ if is_initialized:
518
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
519
+ # update key, value caches with our new 1d spatial slices
520
+ cur_index = cache_index.value
521
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
522
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
523
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
524
+ cached_key.value = key
525
+ cached_value.value = value
526
+ num_updated_cache_vectors = query.shape[1]
527
+ cache_index.value = cache_index.value + num_updated_cache_vectors
528
+ # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
529
+ pad_mask = jnp.broadcast_to(
530
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
531
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
532
+ )
533
+ attention_mask = combine_masks(pad_mask, attention_mask)
534
+ return key, value, attention_mask
535
+
536
+ def __call__(
537
+ self,
538
+ hidden_states,
539
+ attention_mask,
540
+ position_ids,
541
+ deterministic: bool = True,
542
+ init_cache: bool = False,
543
+ output_attentions: bool = False,
544
+ fcm_mask=None,
545
+ ):
546
+ xq, xk, xv = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)
547
+
548
+ xq = with_sharding_constraint(xq, PS(("dp", "fsdp"), None, "mp"))
549
+ xk = with_sharding_constraint(xk, PS(("dp", "fsdp"), None, "mp"))
550
+ xv = with_sharding_constraint(xv, PS(("dp", "fsdp"), None, "mp"))
551
+
552
+ xq = self._split_heads(xq)
553
+ xk = self._split_heads(xk)
554
+ xv = self._split_heads(xv)
555
+
556
+ freqs_cis = jnp.take(self.freqs_cis, position_ids, axis=0)
557
+
558
+ xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis, dtype=self.dtype)
559
+
560
+ dropout_rng = None
561
+ if not deterministic and self.config.attn_pdrop > 0.0:
562
+ dropout_rng = self.make_rng("dropout")
563
+
564
+ if self.config.scan_attention and not (self.has_variable("cache", "cached_key") or init_cache):
565
+ # doesn't need blockwise attention if we are doing autoregressive decoding since no quadratic memory
566
+
567
+ # attention mask without nxn materlization, blockwise_attn will handle the rest
568
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
569
+ # transform boolean mask into float mask
570
+ attention_bias = lax.select(
571
+ attention_mask > 0,
572
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
573
+ jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
574
+ )
575
+ attn_weights = None
576
+ attn_output = blockwise_attn(
577
+ xq,
578
+ xk,
579
+ xv,
580
+ bias=attention_bias,
581
+ deterministic=deterministic,
582
+ dropout_rng=dropout_rng,
583
+ attn_pdrop=self.config.attn_pdrop,
584
+ causal=True,
585
+ query_chunk_size=self.config.scan_query_chunk_size,
586
+ key_chunk_size=self.config.scan_key_chunk_size,
587
+ dtype=self.dtype,
588
+ policy=get_gradient_checkpoint_policy('nothing_saveable'),
589
+ precision=self.precision,
590
+ float32_logits=True,
591
+ prevent_cse=True,
592
+ )
593
+ attn_output = with_sharding_constraint(attn_output, PS(("dp", "fsdp"), None, "mp", None))
594
+ else:
595
+ query_length, key_length = xq.shape[1], xk.shape[1]
596
+
597
+ if self.has_variable("cache", "cached_key"):
598
+ mask_shift = self.variables["cache"]["cache_index"]
599
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
600
+ causal_mask = lax.dynamic_slice(
601
+ self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
602
+ )
603
+ else:
604
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
605
+
606
+ batch_size = hidden_states.shape[0]
607
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
608
+
609
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
610
+ attention_mask = combine_masks(attention_mask, causal_mask, fcm_mask)
611
+
612
+ # During fast autoregressive decoding, we feed one position at a time,
613
+ # and cache the keys and values step by step.
614
+ if self.has_variable("cache", "cached_key") or init_cache:
615
+ xk, xv, attention_mask = self._concatenate_to_cache(xk, xv, xq, attention_mask)
616
+
617
+ # transform boolean mask into float mask
618
+ attention_bias = lax.select(
619
+ attention_mask > 0,
620
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
621
+ jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
622
+ )
623
+ attn_weights = dot_product_attention_weights(
624
+ xq,
625
+ xk,
626
+ bias=attention_bias,
627
+ dropout_rng=dropout_rng,
628
+ dropout_rate=self.config.attn_pdrop,
629
+ deterministic=deterministic,
630
+ dtype=jnp.promote_types(self.dtype, jnp.float32),
631
+ precision=self.precision,
632
+ )
633
+ attn_weights = with_sharding_constraint(attn_weights, PS(("dp", "fsdp"), "mp", None, None))
634
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, xv, precision=self.precision)
635
+
636
+ attn_output = self._merge_heads(attn_output)
637
+ attn_output = self.wo(attn_output)
638
+ attn_output = self.resid_dropout(attn_output, deterministic=deterministic)
639
+ outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
640
+ return outputs
641
+
642
+
643
+ class FlaxLLaMAMLP(nn.Module):
644
+ config: LLaMAConfig
645
+ dtype: jnp.dtype=jnp.float32
646
+ param_dtype: jnp.dtype=jnp.float32
647
+ precision: Optional[Union[jax.lax.Precision, str]]=None
648
+
649
+ def setup(self) -> None:
650
+ config = self.config
651
+
652
+ self.w1 = nn.Dense(
653
+ config.intermediate_size,
654
+ dtype=self.dtype,
655
+ param_dtype=self.param_dtype,
656
+ use_bias=False,
657
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
658
+ precision=self.precision,
659
+ )
660
+ self.w2 = nn.Dense(
661
+ config.hidden_size,
662
+ dtype=self.dtype,
663
+ param_dtype=self.param_dtype,
664
+ use_bias=False,
665
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
666
+ precision=self.precision,
667
+ )
668
+ self.w3 = nn.Dense(
669
+ config.intermediate_size,
670
+ dtype=self.dtype,
671
+ param_dtype=self.param_dtype,
672
+ use_bias=False,
673
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
674
+ precision=self.precision,
675
+ )
676
+ self.dropout = nn.Dropout(rate=self.config.resid_pdrop)
677
+
678
+ def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
679
+ x = self.w2(nn.silu(self.w1(x)) * self.w3(x))
680
+ x = self.dropout(x, deterministic=deterministic)
681
+ return x
682
+
683
+
684
+ class FlaxLLaMABlock(nn.Module):
685
+ config: LLaMAConfig
686
+ dtype: jnp.dtype=jnp.float32
687
+ param_dtype: jnp.dtype=jnp.float32
688
+ precision: Optional[Union[jax.lax.Precision, str]]=None
689
+
690
+ def setup(self) -> None:
691
+ attention_module = FlaxLLaMAAttention
692
+ mlp_module = FlaxLLaMAMLP
693
+ if self.config.remat_attention != '':
694
+ attention_module = remat(
695
+ FlaxLLaMAAttention, static_argnums=(3, 4, 5),
696
+ policy=get_gradient_checkpoint_policy(self.config.remat_attention),
697
+ prevent_cse=True,
698
+ )
699
+ if self.config.remat_mlp != '':
700
+ mlp_module = remat(
701
+ FlaxLLaMAMLP, static_argnums=(1,),
702
+ policy=get_gradient_checkpoint_policy(self.config.remat_mlp),
703
+ prevent_cse=True,
704
+ )
705
+
706
+ self.attention = attention_module(
707
+ self.config,
708
+ dtype=self.dtype,
709
+ param_dtype=self.param_dtype,
710
+ precision=self.precision,
711
+ )
712
+ self.feed_forward = mlp_module(
713
+ self.config,
714
+ dtype=self.dtype,
715
+ param_dtype=self.param_dtype,
716
+ precision=self.precision,
717
+ )
718
+ self.attention_norm = RMSNorm(
719
+ self.config.hidden_size,
720
+ eps=self.config.rms_norm_eps,
721
+ dtype=self.dtype,
722
+ param_dtype=self.param_dtype,
723
+ )
724
+ self.ffn_norm = RMSNorm(
725
+ self.config.hidden_size,
726
+ eps=self.config.rms_norm_eps,
727
+ dtype=self.dtype,
728
+ param_dtype=self.param_dtype,
729
+ )
730
+
731
+ def __call__(
732
+ self,
733
+ hidden_states,
734
+ attention_mask=None,
735
+ position_ids=None,
736
+ deterministic: bool = True,
737
+ init_cache: bool = False,
738
+ output_attentions: bool = False,
739
+ fcm_mask: Optional[jnp.ndarray] = None,
740
+ ):
741
+ attn_outputs = self.attention(
742
+ self.attention_norm(hidden_states),
743
+ attention_mask,
744
+ position_ids,
745
+ deterministic,
746
+ init_cache,
747
+ output_attentions,
748
+ fcm_mask,
749
+ )
750
+ attn_output = attn_outputs[0]
751
+ hidden_states = hidden_states + attn_output
752
+
753
+ feed_forward_input = self.ffn_norm(hidden_states)
754
+
755
+ if self.config.scan_mlp:
756
+ feed_forward_hidden_states = blockwise_ffn(
757
+ self.feed_forward,
758
+ feed_forward_input,
759
+ self.config.scan_mlp_chunk_size,
760
+ deterministic,
761
+ )
762
+ else:
763
+ feed_forward_hidden_states = self.feed_forward(
764
+ feed_forward_input,
765
+ deterministic,
766
+ )
767
+ feed_forward_hidden_states = with_sharding_constraint(feed_forward_hidden_states, PS(("dp", "fsdp"), None, "mp"))
768
+
769
+ hidden_states = hidden_states + feed_forward_hidden_states
770
+
771
+ return (hidden_states,) + attn_outputs[1:]
772
+
773
+
774
+ class FlaxLLaMAPreTrainedModel(FlaxPreTrainedModel):
775
+ """
776
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
777
+ models.
778
+ """
779
+
780
+ config_class = LLaMAConfig
781
+ base_model_prefix = "transformer"
782
+ module_class: nn.Module = None
783
+
784
+ def __init__(
785
+ self,
786
+ config: LLaMAConfig,
787
+ input_shape: Tuple = (1, 1),
788
+ seed: int = 0,
789
+ dtype: jnp.dtype = jnp.float32,
790
+ _do_init: bool = True,
791
+ **kwargs,
792
+ ):
793
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
794
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
795
+
796
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
797
+ # init input tensors
798
+ input_ids = jnp.zeros(input_shape, dtype="i4")
799
+ attention_mask = jnp.ones_like(input_ids)
800
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
801
+ params_rng, dropout_rng = jax.random.split(rng)
802
+ rngs = {"params": params_rng, "dropout": dropout_rng}
803
+
804
+ if self.config.add_cross_attention:
805
+ encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
806
+ encoder_attention_mask = attention_mask
807
+ module_init_outputs = self.module.init(
808
+ rngs,
809
+ input_ids,
810
+ attention_mask,
811
+ position_ids,
812
+ encoder_hidden_states,
813
+ encoder_attention_mask,
814
+ return_dict=False,
815
+ )
816
+ else:
817
+ module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)
818
+
819
+ random_params = module_init_outputs["params"]
820
+
821
+ if params is not None:
822
+ random_params = flatten_dict(unfreeze(random_params))
823
+ params = flatten_dict(unfreeze(params))
824
+ for missing_key in self._missing_keys:
825
+ params[missing_key] = random_params[missing_key]
826
+ self._missing_keys = set()
827
+ return freeze(unflatten_dict(params))
828
+ else:
829
+ return random_params
830
+
831
+ def init_cache(self, batch_size, max_length):
832
+ r"""
833
+ Args:
834
+ batch_size (`int`):
835
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
836
+ max_length (`int`):
837
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
838
+ cache.
839
+ """
840
+ # init input variables to retrieve cache
841
+ input_ids = jnp.ones((batch_size, max_length))
842
+ attention_mask = jnp.ones_like(input_ids)
843
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
844
+
845
+ init_variables = self.module.init(
846
+ jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
847
+ )
848
+ return init_variables["cache"]
849
+
850
+ @add_start_docstrings_to_model_forward("")
851
+ def __call__(
852
+ self,
853
+ input_ids,
854
+ attention_mask=None,
855
+ position_ids=None,
856
+ params: dict = None,
857
+ past_key_values: dict = None,
858
+ dropout_rng: jax.random.PRNGKey = None,
859
+ train: bool = False,
860
+ output_attentions: Optional[bool] = None,
861
+ output_hidden_states: Optional[bool] = None,
862
+ return_dict: Optional[bool] = None,
863
+ ):
864
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
865
+ output_hidden_states = (
866
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
867
+ )
868
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
869
+
870
+ batch_size, sequence_length = input_ids.shape
871
+
872
+ if position_ids is None:
873
+ if past_key_values is not None:
874
+ raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")
875
+
876
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
877
+
878
+ if attention_mask is None:
879
+ attention_mask = jnp.ones((batch_size, sequence_length))
880
+
881
+ # Handle any PRNG if needed
882
+ rngs = {}
883
+ if dropout_rng is not None:
884
+ rngs["dropout"] = dropout_rng
885
+
886
+ inputs = {"params": params or self.params}
887
+
888
+ # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be changed by FlaxGPTJAttention module
889
+ if past_key_values:
890
+ inputs["cache"] = past_key_values
891
+ mutable = ["cache"]
892
+ else:
893
+ mutable = False
894
+
895
+ outputs = self.module.apply(
896
+ inputs,
897
+ jnp.array(input_ids, dtype="i4"),
898
+ jnp.array(attention_mask, dtype="i4"),
899
+ jnp.array(position_ids, dtype="i4"),
900
+ not train,
901
+ False,
902
+ output_attentions,
903
+ output_hidden_states,
904
+ return_dict,
905
+ rngs=rngs,
906
+ mutable=mutable,
907
+ )
908
+
909
+ # add updated cache to model output
910
+ if past_key_values is not None and return_dict:
911
+ outputs, past_key_values = outputs
912
+ outputs["past_key_values"] = unfreeze(past_key_values["cache"])
913
+ return outputs
914
+ elif past_key_values is not None and not return_dict:
915
+ outputs, past_key_values = outputs
916
+ outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
917
+
918
+ return outputs
919
+
920
+
921
+ class FlaxLLaMABlockCollection(nn.Module):
922
+ config: LLaMAConfig
923
+ dtype: jnp.dtype = jnp.float32
924
+ param_dtype: jnp.dtype=jnp.float32
925
+ precision: Optional[Union[jax.lax.Precision, str]]=None
926
+
927
+ def setup(self):
928
+ block = FlaxLLaMABlock
929
+ if self.config.remat_block != '':
930
+ block = remat(
931
+ FlaxLLaMABlock, static_argnums=(3, 4, 5),
932
+ policy=get_gradient_checkpoint_policy(self.config.remat_block)
933
+ )
934
+ self.blocks = [
935
+ block(
936
+ self.config,
937
+ name=str(i),
938
+ dtype=self.dtype,
939
+ param_dtype=self.param_dtype,
940
+ precision=self.precision
941
+ ) for i in range(self.config.num_hidden_layers)
942
+ ]
943
+
944
+ def __call__(
945
+ self,
946
+ hidden_states,
947
+ attention_mask=None,
948
+ position_ids=None,
949
+ deterministic: bool = True,
950
+ init_cache: bool = False,
951
+ output_attentions: bool = False,
952
+ output_hidden_states: bool = False,
953
+ return_dict: bool = True,
954
+ ):
955
+ all_attentions = () if output_attentions else None
956
+ all_hidden_states = () if output_hidden_states else None
957
+
958
+ if not deterministic and self.config.fcm_max_ratio > 0:
959
+ # Apply forgetful causal mask
960
+ batch_size, seq_length = hidden_states.shape[0], hidden_states.shape[1]
961
+ fcm_ratio = jax.random.uniform(
962
+ self.make_rng('fcm'), shape=(batch_size, 1, 1, 1),
963
+ minval=self.config.fcm_min_ratio,
964
+ maxval=self.config.fcm_max_ratio
965
+ )
966
+ fcm_mask = jax.random.uniform(
967
+ self.make_rng('fcm'),
968
+ shape=(batch_size, 1, 1, seq_length)
969
+ ) > fcm_ratio
970
+ fcm_mask = fcm_mask.at[:, :, :, 0].set(True)
971
+ fcm_mask = fcm_mask.astype('bool')
972
+ else:
973
+ fcm_mask = None
974
+
975
+ for block in self.blocks:
976
+ if output_hidden_states:
977
+ all_hidden_states += (hidden_states,)
978
+
979
+ layer_outputs = block(
980
+ hidden_states,
981
+ attention_mask,
982
+ position_ids,
983
+ deterministic,
984
+ init_cache,
985
+ output_attentions,
986
+ fcm_mask,
987
+ )
988
+ hidden_states = layer_outputs[0]
989
+
990
+ if output_attentions:
991
+ all_attentions += (layer_outputs[1],)
992
+
993
+ # this contains possible `None` values - `FlaxGPTJModule` will filter them out
994
+ outputs = (hidden_states, all_hidden_states, all_attentions)
995
+
996
+ return outputs
997
+
998
+
999
+ class FlaxLLaMAModule(nn.Module):
1000
+ config: LLaMAConfig
1001
+ dtype: jnp.dtype = jnp.float32
1002
+ param_dtype: jnp.dtype=jnp.float32
1003
+ precision: Optional[Union[jax.lax.Precision, str]]=None
1004
+
1005
+ def setup(self):
1006
+ self.embed_dim = self.config.hidden_size
1007
+
1008
+ self.wte = nn.Embed(
1009
+ self.config.vocab_size,
1010
+ self.config.hidden_size,
1011
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
1012
+ dtype=self.dtype,
1013
+ param_dtype=self.param_dtype,
1014
+ )
1015
+ self.dropout = nn.Dropout(rate=self.config.embd_pdrop)
1016
+ self.h = FlaxLLaMABlockCollection(self.config, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision)
1017
+ self.ln_f = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps, dtype=self.dtype, param_dtype=self.param_dtype)
1018
+
1019
+ def __call__(
1020
+ self,
1021
+ input_ids,
1022
+ attention_mask,
1023
+ position_ids,
1024
+ deterministic=True,
1025
+ init_cache: bool = False,
1026
+ output_attentions: bool = False,
1027
+ output_hidden_states: bool = False,
1028
+ return_dict: bool = True,
1029
+ ):
1030
+ input_embeds = self.wte(input_ids.astype("i4"))
1031
+
1032
+ hidden_states = self.dropout(input_embeds, deterministic=deterministic)
1033
+
1034
+ outputs = self.h(
1035
+ hidden_states,
1036
+ attention_mask,
1037
+ position_ids=position_ids,
1038
+ deterministic=deterministic,
1039
+ init_cache=init_cache,
1040
+ output_attentions=output_attentions,
1041
+ output_hidden_states=output_hidden_states,
1042
+ return_dict=return_dict,
1043
+ )
1044
+
1045
+ hidden_states = outputs[0]
1046
+ hidden_states = self.ln_f(hidden_states)
1047
+
1048
+ if output_hidden_states:
1049
+ all_hidden_states = outputs[1] + (hidden_states,)
1050
+ outputs = (hidden_states, all_hidden_states) + outputs[2:]
1051
+ else:
1052
+ outputs = (hidden_states,) + outputs[1:]
1053
+
1054
+ if not return_dict:
1055
+ return tuple(v for v in outputs if v is not None)
1056
+
1057
+ return FlaxBaseModelOutput(
1058
+ last_hidden_state=hidden_states,
1059
+ hidden_states=outputs[1],
1060
+ attentions=outputs[-1],
1061
+ )
1062
+
1063
+ @add_start_docstrings("", "")
1064
+ class FlaxLLaMAModel(FlaxLLaMAPreTrainedModel):
1065
+ module_class = FlaxLLaMAModule
1066
+
1067
+ # append_call_sample_docstring(
1068
+ # FlaxLLaMAModel,
1069
+ # _TOKENIZER_FOR_DOC,
1070
+ # _CHECKPOINT_FOR_DOC,
1071
+ # FlaxCausalLMOutput,
1072
+ # _CONFIG_FOR_DOC,
1073
+ # )
1074
+
1075
+ class FlaxLLaMAForCausalLMModule(nn.Module):
1076
+ config: LLaMAConfig
1077
+ dtype: jnp.dtype = jnp.float32
1078
+ param_dtype: jnp.dtype=jnp.float32
1079
+ precision: Optional[Union[jax.lax.Precision, str]]=None
1080
+
1081
+ def setup(self):
1082
+ self.transformer = FlaxLLaMAModule(self.config, dtype=self.dtype)
1083
+ self.lm_head = nn.Dense(
1084
+ self.config.vocab_size,
1085
+ dtype=self.dtype,
1086
+ param_dtype=self.param_dtype,
1087
+ use_bias=False,
1088
+ kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
1089
+ precision=self.precision,
1090
+ )
1091
+
1092
+ def __call__(
1093
+ self,
1094
+ input_ids,
1095
+ attention_mask=None,
1096
+ position_ids=None,
1097
+ deterministic: bool = True,
1098
+ init_cache: bool = False,
1099
+ output_attentions: bool = False,
1100
+ output_hidden_states: bool = False,
1101
+ return_dict: bool = True,
1102
+ ):
1103
+ batch_size, seq_length = input_ids.shape
1104
+ if attention_mask is None:
1105
+ attention_mask = jnp.ones_like(input_ids)
1106
+ if position_ids is None:
1107
+ position_ids = jnp.broadcast_to(
1108
+ jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0),
1109
+ (batch_size, seq_length)
1110
+ )
1111
+ outputs = self.transformer(
1112
+ input_ids,
1113
+ attention_mask,
1114
+ position_ids,
1115
+ deterministic=deterministic,
1116
+ init_cache=init_cache,
1117
+ output_attentions=output_attentions,
1118
+ output_hidden_states=output_hidden_states,
1119
+ return_dict=return_dict,
1120
+ )
1121
+
1122
+ hidden_states = outputs[0]
1123
+
1124
+ if self.config.tie_word_embeddings:
1125
+ shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T
1126
+ lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
1127
+ else:
1128
+ lm_logits = self.lm_head(hidden_states)
1129
+
1130
+ if not return_dict:
1131
+ return (lm_logits,) + outputs[1:]
1132
+
1133
+ return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
1134
+
1135
+
1136
+ @add_start_docstrings("", "")
1137
+ class FlaxLLaMAForCausalLM(FlaxLLaMAPreTrainedModel):
1138
+ module_class = FlaxLLaMAForCausalLMModule
1139
+
1140
+ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
1141
+ # initializing the cache
1142
+ batch_size, seq_length = input_ids.shape
1143
+
1144
+ past_key_values = self.init_cache(batch_size, max_length)
1145
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
1146
+ # But since GPTJ uses a causal mask, those positions are masked anyways.
1147
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation
1148
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
1149
+ if attention_mask is not None:
1150
+ position_ids = attention_mask.cumsum(axis=-1) - 1
1151
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
1152
+ else:
1153
+ position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
1154
+
1155
+ return {
1156
+ "past_key_values": past_key_values,
1157
+ "attention_mask": extended_attention_mask,
1158
+ "position_ids": position_ids,
1159
+ }
1160
+
1161
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
1162
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
1163
+ model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
1164
+ return model_kwargs
1165
+
1166
+ # append_call_sample_docstring(
1167
+ # FlaxGPTJForCausalLM,
1168
+ # _TOKENIZER_FOR_DOC,
1169
+ # _CHECKPOINT_FOR_DOC,
1170
+ # FlaxCausalLMOutput,
1171
+ # _CONFIG_FOR_DOC,
1172
+ # )
1173
+
1174
+
1175
+
1176
+ VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
1177
+
1178
+ PRETRAINED_VOCAB_FILES_MAP = {}
1179
+
1180
+
1181
+ class LLaMATokenizer(PreTrainedTokenizer):
1182
+ """
1183
+ Construct a LLaMA tokenizer. Based on byte-level Byte-Pair-Encoding.
1184
+ Args:
1185
+ vocab_file (`str`):
1186
+ Path to the vocabulary file.
1187
+ """
1188
+
1189
+ vocab_files_names = VOCAB_FILES_NAMES
1190
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
1191
+ model_input_names = ["input_ids", "attention_mask"]
1192
+
1193
+ def __init__(
1194
+ self,
1195
+ vocab_file,
1196
+ unk_token="<unk>",
1197
+ bos_token="<s>",
1198
+ eos_token="</s>",
1199
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
1200
+ add_bos_token=False,
1201
+ add_eos_token=False,
1202
+ **kwargs,
1203
+ ):
1204
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
1205
+ super().__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs)
1206
+ self.vocab_file = vocab_file
1207
+ self.add_bos_token = add_bos_token
1208
+ self.add_eos_token = add_eos_token
1209
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
1210
+
1211
+ with tempfile.NamedTemporaryFile() as tfile:
1212
+ with open_file(self.vocab_file, 'rb') as fin:
1213
+ tfile.write(fin.read())
1214
+ tfile.flush()
1215
+ tfile.seek(0)
1216
+ self.sp_model.Load(tfile.name)
1217
+ """ Initialisation"""
1218
+ self.add_special_tokens(dict(
1219
+ unk_token=unk_token,
1220
+ bos_token=bos_token,
1221
+ eos_token=eos_token,
1222
+ ))
1223
+ self.pad_token_id = self.unk_token_id
1224
+
1225
+ @property
1226
+ def vocab_size(self):
1227
+ """Returns vocab size"""
1228
+ return self.sp_model.get_piece_size()
1229
+
1230
+ @property
1231
+ def bos_token_id(self) -> Optional[int]:
1232
+ return self.sp_model.bos_id()
1233
+
1234
+ @property
1235
+ def eos_token_id(self) -> Optional[int]:
1236
+ return self.sp_model.eos_id()
1237
+
1238
+ def get_vocab(self):
1239
+ """Returns vocab as a dict"""
1240
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
1241
+ vocab.update(self.added_tokens_encoder)
1242
+ return vocab
1243
+
1244
+ def _tokenize(self, text):
1245
+ """Returns a tokenized string."""
1246
+ return self.sp_model.encode(text, out_type=str)
1247
+
1248
+ def _convert_token_to_id(self, token):
1249
+ """Converts a token (str) in an id using the vocab."""
1250
+ return self.sp_model.piece_to_id(token)
1251
+
1252
+ def _convert_id_to_token(self, index):
1253
+ """Converts an index (integer) in a token (str) using the vocab."""
1254
+ token = self.sp_model.IdToPiece(index)
1255
+ return token
1256
+
1257
+ def convert_tokens_to_string(self, tokens):
1258
+ """Converts a sequence of tokens (string) in a single string."""
1259
+ current_sub_tokens = []
1260
+ out_string = ""
1261
+ prev_is_special = False
1262
+ for token in tokens:
1263
+ # make sure that special tokens are not decoded using sentencepiece model
1264
+ if token in self.all_special_tokens:
1265
+ if not prev_is_special:
1266
+ out_string += " "
1267
+ out_string += self.sp_model.decode(current_sub_tokens) + token
1268
+ prev_is_special = True
1269
+ current_sub_tokens = []
1270
+ else:
1271
+ current_sub_tokens.append(token)
1272
+ prev_is_special = False
1273
+ out_string += self.sp_model.decode(current_sub_tokens)
1274
+ return out_string.strip()
1275
+
1276
+ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
1277
+ """
1278
+ Save the vocabulary and special tokens file to a directory.
1279
+ Args:
1280
+ save_directory (`str`):
1281
+ The directory in which to save the vocabulary.
1282
+ Returns:
1283
+ `Tuple(str)`: Paths to the files saved.
1284
+ """
1285
+ if not os.path.isdir(save_directory):
1286
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
1287
+ return
1288
+ out_vocab_file = os.path.join(
1289
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
1290
+ )
1291
+
1292
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
1293
+ copyfile(self.vocab_file, out_vocab_file)
1294
+ elif not os.path.isfile(self.vocab_file):
1295
+ with open(out_vocab_file, "wb") as fi:
1296
+ content_spiece_model = self.sp_model.serialized_model_proto()
1297
+ fi.write(content_spiece_model)
1298
+
1299
+ return (out_vocab_file,)
1300
+
1301
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
1302
+ if self.add_bos_token:
1303
+ bos_token_ids = [self.bos_token_id]
1304
+ else:
1305
+ bos_token_ids = []
1306
+
1307
+ output = bos_token_ids + token_ids_0
1308
+
1309
+ if token_ids_1 is not None:
1310
+ output = output + token_ids_1
1311
+
1312
+ if self.add_eos_token:
1313
+ output = output + [self.eos_token_id]
1314
+
1315
+ return output
1316
+
1317
+ def get_special_tokens_mask(
1318
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
1319
+ ) -> List[int]:
1320
+ """
1321
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
1322
+ special tokens using the tokenizer `prepare_for_model` method.
1323
+ Args:
1324
+ token_ids_0 (`List[int]`):
1325
+ List of IDs.
1326
+ token_ids_1 (`List[int]`, *optional*):
1327
+ Optional second list of IDs for sequence pairs.
1328
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
1329
+ Whether or not the token list is already formatted with special tokens for the model.
1330
+ Returns:
1331
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
1332
+ """
1333
+ if already_has_special_tokens:
1334
+ return super().get_special_tokens_mask(
1335
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
1336
+ )
1337
+
1338
+ if token_ids_1 is None:
1339
+ return [1] + ([0] * len(token_ids_0)) + [1]
1340
+ return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
1341
+
1342
+ def create_token_type_ids_from_sequences(
1343
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
1344
+ ) -> List[int]:
1345
+ """
1346
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
1347
+ use of token type ids, therefore a list of zeros is returned.
1348
+ Args:
1349
+ token_ids_0 (`List[int]`):
1350
+ List of IDs.
1351
+ token_ids_1 (`List[int]`, *optional*):
1352
+ Optional second list of IDs for sequence pairs.
1353
+ Returns:
1354
+ `List[int]`: List of zeros.
1355
+ """
1356
+ eos = [self.eos_token_id]
1357
+
1358
+ if token_ids_1 is None:
1359
+ return len(token_ids_0 + eos) * [0]
1360
+ return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
EasyLM/models/llama/llama_serve.py ADDED
@@ -0,0 +1,386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pprint
2
+ from functools import partial
3
+
4
+ import numpy as np
5
+ import mlxu
6
+
7
+ import jax
8
+ import jax.numpy as jnp
9
+ from jax.experimental.pjit import pjit
10
+ from jax.sharding import PartitionSpec as PS
11
+ import optax
12
+ from transformers import GenerationConfig, FlaxLogitsProcessorList
13
+
14
+ from EasyLM.checkpoint import StreamingCheckpointer
15
+ from EasyLM.serving import LMServer
16
+ from EasyLM.jax_utils import (
17
+ JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, tree_apply,
18
+ set_random_seed, get_float_dtype_by_name, make_shard_and_gather_fns,
19
+ with_sharding_constraint, FlaxTemperatureLogitsWarper
20
+ )
21
+ from EasyLM.models.llama.llama_model import LLaMAConfig, FlaxLLaMAForCausalLM
22
+
23
+
24
+ FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
25
+ seed=42,
26
+ initialize_jax_distributed=False,
27
+ mesh_dim='1,-1,1',
28
+ dtype='bf16',
29
+ input_length=1024,
30
+ seq_length=2048,
31
+ top_k=50,
32
+ top_p=1.0,
33
+ do_sample=True,
34
+ num_beams=1,
35
+ add_bos_token=True,
36
+ load_llama_config='',
37
+ load_checkpoint='',
38
+ tokenizer=LLaMAConfig.get_tokenizer_config(),
39
+ lm_server=LMServer.get_default_config(),
40
+ jax_distributed=JaxDistributedConfig.get_default_config(),
41
+ )
42
+
43
+
44
+ def main(argv):
45
+ JaxDistributedConfig.initialize(FLAGS.jax_distributed)
46
+ set_random_seed(FLAGS.seed)
47
+
48
+ prefix_tokenizer = LLaMAConfig.get_tokenizer(
49
+ FLAGS.tokenizer, truncation_side='left', padding_side='left'
50
+ )
51
+ tokenizer = LLaMAConfig.get_tokenizer(
52
+ FLAGS.tokenizer, truncation_side='right', padding_side='right'
53
+ )
54
+
55
+ with jax.default_device(jax.devices("cpu")[0]):
56
+ llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config)
57
+ _, params = StreamingCheckpointer.load_trainstate_checkpoint(
58
+ FLAGS.load_checkpoint, disallow_trainstate=True
59
+ )
60
+
61
+ hf_model = FlaxLLaMAForCausalLM(
62
+ llama_config,
63
+ input_shape=(1, FLAGS.seq_length),
64
+ seed=FLAGS.seed,
65
+ _do_init=False
66
+ )
67
+
68
+ model_ps = match_partition_rules(
69
+ LLaMAConfig.get_partition_rules(), params
70
+ )
71
+ shard_fns, _ = make_shard_and_gather_fns(
72
+ model_ps, get_float_dtype_by_name(FLAGS.dtype)
73
+ )
74
+
75
+ @partial(
76
+ pjit,
77
+ in_shardings=(model_ps, PS(), PS()),
78
+ out_shardings=(PS(), PS(), PS())
79
+ )
80
+ def forward_loglikelihood(params, rng, batch):
81
+ batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
82
+ rng_generator = JaxRNG(rng)
83
+ input_tokens = batch['input_tokens']
84
+ output_tokens = batch['output_tokens']
85
+ input_mask = batch['input_mask']
86
+ output_mask = batch['output_mask']
87
+
88
+ logits = hf_model.module.apply(
89
+ params, input_tokens, attention_mask=input_mask,
90
+ deterministic=True, rngs=rng_generator(llama_config.rng_keys()),
91
+ ).logits
92
+ # if llama_config.n_real_tokens is not None:
93
+ # logits = logits.at[:, :, llama_config.n_real_tokens:].set(-1e8)
94
+ loglikelihood = -optax.softmax_cross_entropy_with_integer_labels(
95
+ logits, output_tokens
96
+ )
97
+ loglikelihood = jnp.sum(loglikelihood * output_mask, axis=-1)
98
+ match_count = jnp.sum(
99
+ (jnp.argmax(logits, axis=-1) == output_tokens) * output_mask,
100
+ axis=-1
101
+ )
102
+ total = jnp.sum(output_mask, axis=-1)
103
+ is_greedy = match_count == total
104
+ return loglikelihood, is_greedy, rng_generator()
105
+
106
+
107
+ @partial(
108
+ pjit,
109
+ in_shardings=(model_ps, PS(), PS(), PS()),
110
+ out_shardings=(PS(), PS())
111
+ )
112
+ def forward_generate(params, rng, batch, temperature):
113
+ batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
114
+ rng_generator = JaxRNG(rng)
115
+ output = hf_model.generate(
116
+ batch['input_tokens'],
117
+ attention_mask=batch['attention_mask'],
118
+ params=params['params'],
119
+ prng_key=rng_generator(),
120
+ logits_processor=FlaxLogitsProcessorList(
121
+ [FlaxTemperatureLogitsWarper(temperature)]
122
+ ),
123
+ generation_config=GenerationConfig(
124
+ max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
125
+ pad_token_id=tokenizer.eos_token_id,
126
+ bos_token_id=tokenizer.bos_token_id,
127
+ eos_token_id=tokenizer.eos_token_id,
128
+ do_sample=FLAGS.do_sample,
129
+ num_beams=FLAGS.num_beams,
130
+ top_k=FLAGS.top_k,
131
+ top_p=FLAGS.top_p,
132
+ )
133
+ ).sequences[:, batch['input_tokens'].shape[1]:]
134
+ return output, rng_generator()
135
+
136
+ @partial(
137
+ pjit,
138
+ in_shardings=(model_ps, PS(), PS()),
139
+ out_shardings=(PS(), PS())
140
+ )
141
+ def forward_greedy_generate(params, rng, batch):
142
+ batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
143
+ rng_generator = JaxRNG(rng)
144
+ output = hf_model.generate(
145
+ batch['input_tokens'],
146
+ attention_mask=batch['attention_mask'],
147
+ params=params['params'],
148
+ prng_key=rng_generator(),
149
+ generation_config=GenerationConfig(
150
+ max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
151
+ pad_token_id=tokenizer.eos_token_id,
152
+ bos_token_id=tokenizer.bos_token_id,
153
+ eos_token_id=tokenizer.eos_token_id,
154
+ do_sample=False,
155
+ num_beams=1,
156
+ )
157
+ ).sequences[:, batch['input_tokens'].shape[1]:]
158
+ return output, rng_generator()
159
+
160
+ mesh = LLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
161
+ with mesh:
162
+ params = tree_apply(shard_fns, params)
163
+ sharded_rng = next_rng()
164
+
165
+ class ModelServer(LMServer):
166
+
167
+ @staticmethod
168
+ def loglikelihood(prefix_text, text):
169
+ nonlocal sharded_rng
170
+ prefix = prefix_tokenizer(
171
+ prefix_text,
172
+ padding='max_length',
173
+ truncation=True,
174
+ max_length=FLAGS.input_length,
175
+ return_tensors='np',
176
+ )
177
+ inputs = tokenizer(
178
+ text,
179
+ padding='max_length',
180
+ truncation=True,
181
+ max_length=FLAGS.seq_length - FLAGS.input_length,
182
+ return_tensors='np',
183
+ )
184
+ output_tokens = np.concatenate([prefix.input_ids, inputs.input_ids], axis=1)
185
+ bos_tokens = np.full(
186
+ (output_tokens.shape[0], 1), tokenizer.bos_token_id, dtype=np.int32
187
+ )
188
+ input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1)
189
+ input_mask = np.concatenate(
190
+ [prefix.attention_mask, inputs.attention_mask], axis=1
191
+ )
192
+ if FLAGS.add_bos_token:
193
+ bos_mask = np.ones_like(input_mask[:, :1])
194
+ else:
195
+ bos_mask = np.zeros_like(input_mask[:, :1])
196
+
197
+ input_mask = np.concatenate([bos_mask, input_mask[:, :-1]], axis=1)
198
+ output_mask = np.concatenate(
199
+ [np.zeros_like(prefix.attention_mask), inputs.attention_mask], axis=1
200
+ )
201
+ batch = dict(
202
+ input_tokens=input_tokens,
203
+ output_tokens=output_tokens,
204
+ input_mask=input_mask,
205
+ output_mask=output_mask,
206
+ )
207
+ with mesh:
208
+ loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
209
+ params, sharded_rng, batch
210
+ )
211
+ loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy))
212
+ return loglikelihood, is_greedy
213
+
214
+ @staticmethod
215
+ def loglikelihood_rolling(text):
216
+ nonlocal sharded_rng
217
+ inputs = tokenizer(
218
+ text,
219
+ padding='longest',
220
+ truncation=False,
221
+ max_length=np.iinfo(np.int32).max,
222
+ return_tensors='np',
223
+ )
224
+ batch_size = inputs.input_ids.shape[0]
225
+ output_tokens = inputs.input_ids
226
+ attention_mask = inputs.attention_mask
227
+
228
+ if output_tokens.shape[1] < FLAGS.seq_length:
229
+ padding_length = FLAGS.seq_length - output_tokens.shape[1]
230
+ pad_tokens = np.full(
231
+ (batch_size, padding_length), tokenizer.pad_token_id, dtype=np.int32
232
+ )
233
+ output_tokens = np.concatenate([output_tokens, pad_tokens], axis=-1)
234
+ pad_mask = np.zeros(
235
+ (batch_size, padding_length), dtype=inputs.attention_mask.dtype
236
+ )
237
+ attention_mask = np.concatenate([attention_mask, pad_mask], axis=-1)
238
+
239
+ bos_tokens = np.full(
240
+ (batch_size, 1), tokenizer.bos_token_id, dtype=np.int32
241
+ )
242
+ input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1)
243
+ bos_mask = np.ones((batch_size, 1), dtype=inputs.attention_mask.dtype)
244
+ total_seq_length = output_tokens.shape[1]
245
+
246
+ total_loglikelihood = 0.0
247
+ total_is_greedy = True
248
+ # Sliding window
249
+ for i in range(0, total_seq_length, FLAGS.seq_length):
250
+ # Last window
251
+ if i + FLAGS.seq_length > total_seq_length:
252
+ last_output_mask = np.copy(attention_mask[:, -FLAGS.seq_length:])
253
+ last_output_mask[:, :i - total_seq_length] = 0.0
254
+
255
+ batch = dict(
256
+ input_tokens=input_tokens[:, -FLAGS.seq_length:],
257
+ output_tokens=output_tokens[:, -FLAGS.seq_length:],
258
+ input_mask=attention_mask[:, -FLAGS.seq_length:],
259
+ output_mask=last_output_mask,
260
+ )
261
+
262
+ # Normal window
263
+ else:
264
+ batch = dict(
265
+ input_tokens=input_tokens[:, i:i + FLAGS.seq_length],
266
+ output_tokens=output_tokens[:, i:i + FLAGS.seq_length],
267
+ input_mask=attention_mask[:, i:i + FLAGS.seq_length],
268
+ output_mask=attention_mask[:, i:i + FLAGS.seq_length],
269
+ )
270
+
271
+ with mesh:
272
+ loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
273
+ params, sharded_rng, batch
274
+ )
275
+ loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy))
276
+
277
+ total_loglikelihood += loglikelihood
278
+ total_is_greedy = np.logical_and(is_greedy, total_is_greedy)
279
+
280
+ return total_loglikelihood, total_is_greedy
281
+
282
+ @staticmethod
283
+ def generate(text, temperature):
284
+ nonlocal sharded_rng
285
+ inputs = prefix_tokenizer(
286
+ text,
287
+ padding='max_length',
288
+ truncation=True,
289
+ max_length=FLAGS.input_length,
290
+ return_tensors='np',
291
+ )
292
+ input_tokens = inputs.input_ids
293
+ input_mask = inputs.attention_mask
294
+ if FLAGS.add_bos_token:
295
+ input_tokens[:, 0] = tokenizer.bos_token_id
296
+ input_mask[:, 0] = 1
297
+ batch = dict(
298
+ input_tokens=input_tokens,
299
+ attention_mask=input_mask,
300
+ )
301
+ with mesh:
302
+ output, sharded_rng = forward_generate(
303
+ params, sharded_rng, batch, temperature
304
+ )
305
+ output = jax.device_get(output)
306
+ output_text = []
307
+ for text in list(tokenizer.batch_decode(output)):
308
+ if tokenizer.eos_token in text:
309
+ text = text.split(tokenizer.eos_token, maxsplit=1)[0]
310
+ output_text.append(text)
311
+
312
+ return output_text
313
+
314
+ @staticmethod
315
+ def greedy_until(prefix_text, until, max_length):
316
+ nonlocal sharded_rng
317
+ all_outputs = []
318
+ for pf, ut in zip(prefix_text, until):
319
+ if isinstance(ut, str):
320
+ ut = [ut]
321
+ total_length = 0
322
+ total_generated = ''
323
+
324
+ while total_length < max_length:
325
+ pf_tokens = tokenizer(
326
+ pf,
327
+ padding=False,
328
+ truncation=False,
329
+ max_length=np.iinfo(np.int32).max,
330
+ return_tensors='np',
331
+ )
332
+ input_tokens = pf_tokens.input_ids
333
+ attention_mask = pf_tokens.attention_mask
334
+
335
+ if input_tokens.shape[1] < FLAGS.input_length:
336
+ extra = FLAGS.input_length - input_tokens.shape[1]
337
+ pad_tokens = np.full(
338
+ (1, extra), tokenizer.pad_token_id, dtype=np.int32
339
+ )
340
+ input_tokens = np.concatenate(
341
+ [pad_tokens, input_tokens], axis=1
342
+ )
343
+ pad_attention = np.zeros((1, extra), dtype=attention_mask.dtype)
344
+ attention_mask = np.concatenate(
345
+ [pad_attention, attention_mask], axis=1
346
+ )
347
+ elif input_tokens.shape[1] > FLAGS.input_length:
348
+ input_tokens = input_tokens[:, -FLAGS.input_length:]
349
+ attention_mask = attention_mask[:, -FLAGS.input_length:]
350
+
351
+ if FLAGS.add_bos_token:
352
+ input_tokens[:, 0] = tokenizer.bos_token_id
353
+ attention_mask[:, 0] = 1
354
+
355
+ batch = dict(input_tokens=input_tokens, attention_mask=attention_mask)
356
+
357
+ with mesh:
358
+ output, sharded_rng = forward_greedy_generate(
359
+ params, sharded_rng, batch
360
+ )
361
+ output = jax.device_get(output)
362
+
363
+ total_length += output.shape[1]
364
+ output_text = tokenizer.batch_decode(output)[0]
365
+ total_generated = total_generated + output_text
366
+ pf = pf + output_text
367
+
368
+ done = False
369
+ for s in ut:
370
+ if s in total_generated:
371
+ total_generated = total_generated.split(s, maxsplit=1)[0]
372
+ done = True
373
+ if done:
374
+ break
375
+
376
+ all_outputs.append(total_generated)
377
+
378
+ return all_outputs
379
+
380
+
381
+ server = ModelServer(FLAGS.lm_server)
382
+ server.run()
383
+
384
+
385
+ if __name__ == "__main__":
386
+ mlxu.run(main)
EasyLM/models/llama/llama_train.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pprint
2
+ from functools import partial
3
+
4
+ from tqdm import tqdm, trange
5
+ import numpy as np
6
+ import mlxu
7
+
8
+ import jax
9
+ import jax.numpy as jnp
10
+ from jax.experimental.pjit import pjit
11
+ from jax.sharding import PartitionSpec as PS
12
+ from flax.training.train_state import TrainState
13
+
14
+ from EasyLM.data import DatasetFactory
15
+ from EasyLM.checkpoint import StreamingCheckpointer
16
+ from EasyLM.optimizers import OptimizerFactory
17
+ from EasyLM.jax_utils import (
18
+ JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules,
19
+ cross_entropy_loss_and_accuracy, global_norm, get_float_dtype_by_name,
20
+ set_random_seed, average_metrics, get_weight_decay_mask,
21
+ make_shard_and_gather_fns, with_sharding_constraint,
22
+ )
23
+ from EasyLM.models.llama.llama_model import (
24
+ LLaMAConfig, FlaxLLaMAForCausalLMModule
25
+ )
26
+
27
+
28
+ FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
29
+ seed=42,
30
+ mesh_dim='1,-1,1',
31
+ dtype='fp32',
32
+ param_dtype='fp32',
33
+ total_steps=10000,
34
+ load_llama_config='',
35
+ update_llama_config='',
36
+ load_checkpoint='',
37
+ load_dataset_state='',
38
+ log_freq=50,
39
+ save_model_freq=0,
40
+ save_milestone_freq=0,
41
+ eval_freq=0,
42
+ tokenizer=LLaMAConfig.get_tokenizer_config(),
43
+ train_dataset=DatasetFactory.get_default_config(),
44
+ eval_dataset=DatasetFactory.get_default_config(),
45
+ optimizer=OptimizerFactory.get_default_config(),
46
+ checkpointer=StreamingCheckpointer.get_default_config(),
47
+ llama=LLaMAConfig.get_default_config(),
48
+ logger=mlxu.WandBLogger.get_default_config(),
49
+ log_all_worker=False,
50
+ jax_distributed=JaxDistributedConfig.get_default_config(),
51
+ )
52
+
53
+
54
+ def main(argv):
55
+ JaxDistributedConfig.initialize(FLAGS.jax_distributed)
56
+ variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
57
+ flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
58
+ logger = mlxu.WandBLogger(
59
+ config=FLAGS.logger,
60
+ variant=variant,
61
+ enable=FLAGS.log_all_worker or (jax.process_index() == 0),
62
+ )
63
+ set_random_seed(FLAGS.seed)
64
+
65
+ tokenizer = LLaMAConfig.get_tokenizer(FLAGS.tokenizer)
66
+ dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer)
67
+ if FLAGS.load_dataset_state != '':
68
+ dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state))
69
+
70
+ if FLAGS.eval_freq > 0:
71
+ eval_dataset = DatasetFactory.load_dataset(
72
+ FLAGS.eval_dataset, dataset.tokenizer, eval_dataset=True
73
+ )
74
+
75
+ seq_length = dataset.seq_length
76
+
77
+ if FLAGS.load_llama_config != '':
78
+ llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config)
79
+ else:
80
+ llama_config = LLaMAConfig(**FLAGS.llama)
81
+
82
+ if FLAGS.update_llama_config != '':
83
+ llama_config.update(dict(eval(FLAGS.update_llama_config)))
84
+
85
+ llama_config.update(dict(
86
+ bos_token_id=dataset.tokenizer.bos_token_id,
87
+ eos_token_id=dataset.tokenizer.eos_token_id,
88
+ ))
89
+ if llama_config.vocab_size < dataset.vocab_size:
90
+ print("Updating model config vocab size from", llama_config.vocab_size, "to", dataset.vocab_size)
91
+ llama_config.update(dict(vocab_size=dataset.vocab_size))
92
+
93
+ model = FlaxLLaMAForCausalLMModule(
94
+ llama_config, dtype=get_float_dtype_by_name(FLAGS.dtype), param_dtype=get_float_dtype_by_name(FLAGS.param_dtype)
95
+ )
96
+
97
+ optimizer, optimizer_info = OptimizerFactory.get_optimizer(
98
+ FLAGS.optimizer,
99
+ get_weight_decay_mask(LLaMAConfig.get_weight_decay_exclusions())
100
+ )
101
+
102
+ def create_trainstate_from_params(params):
103
+ return TrainState.create(params=params, tx=optimizer, apply_fn=None)
104
+
105
+ def init_fn(rng):
106
+ rng_generator = JaxRNG(rng)
107
+ params = model.init(
108
+ input_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
109
+ position_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
110
+ attention_mask=jnp.ones((4, seq_length), dtype=jnp.int32),
111
+ rngs=rng_generator(llama_config.rng_keys()),
112
+ )
113
+ return TrainState.create(params=params, tx=optimizer, apply_fn=None)
114
+
115
+ def train_step(train_state, rng, batch):
116
+ rng_generator = JaxRNG(rng)
117
+ batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
118
+ def loss_and_accuracy(params):
119
+ logits = model.apply(
120
+ params, batch['input_tokens'], deterministic=False,
121
+ rngs=rng_generator(llama_config.rng_keys()),
122
+ ).logits
123
+ return cross_entropy_loss_and_accuracy(
124
+ logits, batch['target_tokens'], batch['loss_masks']
125
+ )
126
+ grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
127
+ (loss, accuracy), grads = grad_fn(train_state.params)
128
+ train_state = train_state.apply_gradients(grads=grads)
129
+ metrics = dict(
130
+ loss=loss,
131
+ accuracy=accuracy,
132
+ learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
133
+ gradient_norm=global_norm(grads),
134
+ param_norm=global_norm(train_state.params),
135
+ )
136
+ return train_state, rng_generator(), metrics
137
+
138
+ def eval_step(train_state, rng, batch):
139
+ rng_generator = JaxRNG(rng)
140
+ batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
141
+ logits = model.apply(
142
+ train_state.params, batch['input_tokens'], deterministic=True,
143
+ rngs=rng_generator(llama_config.rng_keys()),
144
+ ).logits
145
+ loss, accuracy = cross_entropy_loss_and_accuracy(
146
+ logits, batch['target_tokens'], batch['loss_masks']
147
+ )
148
+ metrics = dict(
149
+ eval_loss=loss,
150
+ eval_accuracy=accuracy,
151
+ )
152
+ return rng_generator(), metrics
153
+
154
+ train_state_shapes = jax.eval_shape(init_fn, next_rng())
155
+ train_state_partition = match_partition_rules(
156
+ LLaMAConfig.get_partition_rules(), train_state_shapes
157
+ )
158
+
159
+ shard_fns, gather_fns = make_shard_and_gather_fns(
160
+ train_state_partition, train_state_shapes
161
+ )
162
+ checkpointer = StreamingCheckpointer(
163
+ FLAGS.checkpointer, logger.output_dir,
164
+ enable=jax.process_index() == 0,
165
+ )
166
+
167
+ sharded_init_fn = pjit(
168
+ init_fn,
169
+ in_shardings=PS(),
170
+ out_shardings=train_state_partition
171
+ )
172
+
173
+ sharded_create_trainstate_from_params = pjit(
174
+ create_trainstate_from_params,
175
+ in_shardings=(train_state_partition.params, ),
176
+ out_shardings=train_state_partition,
177
+ donate_argnums=(0, ),
178
+ )
179
+
180
+ sharded_train_step = pjit(
181
+ train_step,
182
+ in_shardings=(train_state_partition, PS(), PS()),
183
+ out_shardings=(train_state_partition, PS(), PS()),
184
+ donate_argnums=(0, 1),
185
+ )
186
+
187
+ sharded_eval_step = pjit(
188
+ eval_step,
189
+ in_shardings=(train_state_partition, PS(), PS()),
190
+ out_shardings=(PS(), PS()),
191
+ donate_argnums=(1,),
192
+ )
193
+
194
+ def save_checkpoint(train_state, milestone=False):
195
+ step = int(jax.device_get(train_state.step))
196
+ metadata = dict(
197
+ step=step,
198
+ variant=variant,
199
+ flags=flags_config_dict,
200
+ llama_config=llama_config.to_dict(),
201
+ )
202
+ checkpointer.save_all(
203
+ train_state=train_state,
204
+ gather_fns=gather_fns,
205
+ metadata=metadata,
206
+ dataset=dataset.get_state_dict(),
207
+ milestone=milestone,
208
+ )
209
+
210
+ mesh = LLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
211
+ with mesh:
212
+ train_state, restored_params = None, None
213
+ if FLAGS.load_checkpoint != '':
214
+ train_state, restored_params = checkpointer.load_trainstate_checkpoint(
215
+ FLAGS.load_checkpoint, train_state_shapes, shard_fns
216
+ )
217
+
218
+ if train_state is None and restored_params is None:
219
+ # Initialize from scratch
220
+ train_state = sharded_init_fn(next_rng())
221
+ elif train_state is None and restored_params is not None:
222
+ # Restore from params but initialize train_state
223
+ train_state = sharded_create_trainstate_from_params(restored_params)
224
+ del restored_params
225
+
226
+ start_step = int(jax.device_get(train_state.step))
227
+
228
+ if FLAGS.save_model_freq > 0:
229
+ save_checkpoint(train_state)
230
+
231
+ sharded_rng = next_rng()
232
+
233
+ step_counter = trange(start_step, FLAGS.total_steps, ncols=0)
234
+
235
+ for step, (batch, dataset_metrics) in zip(step_counter, dataset):
236
+ train_state, sharded_rng, metrics = sharded_train_step(
237
+ train_state, sharded_rng, batch
238
+ )
239
+
240
+ if FLAGS.eval_freq > 0 and (step + 1) % FLAGS.eval_freq == 0:
241
+ eval_metric_list = []
242
+ eval_iterator = iter(eval_dataset)
243
+ for eval_batch, _ in eval_iterator:
244
+ sharded_rng, eval_metrics = sharded_eval_step(
245
+ train_state, sharded_rng, eval_batch
246
+ )
247
+ eval_metric_list.append(eval_metrics)
248
+ metrics.update(average_metrics(eval_metric_list))
249
+
250
+ if FLAGS.log_freq > 0 and (step + 1) % FLAGS.log_freq == 0:
251
+ log_metrics = {"step": step + 1}
252
+ log_metrics.update(metrics)
253
+ log_metrics.update(dataset_metrics)
254
+ log_metrics = jax.device_get(log_metrics)
255
+ logger.log(log_metrics)
256
+ tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")
257
+
258
+ if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0:
259
+ save_checkpoint(train_state, milestone=True)
260
+ elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0:
261
+ save_checkpoint(train_state)
262
+
263
+ if FLAGS.save_model_freq > 0:
264
+ save_checkpoint(train_state)
265
+
266
+
267
+ if __name__ == "__main__":
268
+ mlxu.run(main)
EasyLM/models/roberta/__init__.py ADDED
File without changes
EasyLM/models/roberta/roberta_model.py ADDED
@@ -0,0 +1,1694 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
3
+ # Modifications copyright 2022 Xinyang Geng
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ from typing import Callable, Optional, Tuple
17
+ from collections import OrderedDict
18
+ from typing import Mapping
19
+
20
+ import numpy as np
21
+
22
+ import flax.linen as nn
23
+ import jax
24
+ import jax.numpy as jnp
25
+ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
26
+ from flax.linen import combine_masks, make_causal_mask
27
+ from flax.linen import partitioning as nn_partitioning
28
+ from flax.linen.attention import dot_product_attention_weights
29
+ from flax.traverse_util import flatten_dict, unflatten_dict
30
+ from jax import lax
31
+ from jax.sharding import PartitionSpec
32
+
33
+ from transformers.configuration_utils import PretrainedConfig
34
+ from transformers.modeling_flax_outputs import (
35
+ FlaxBaseModelOutputWithPastAndCrossAttentions,
36
+ FlaxBaseModelOutputWithPooling,
37
+ FlaxBaseModelOutputWithPoolingAndCrossAttentions,
38
+ FlaxCausalLMOutputWithCrossAttentions,
39
+ FlaxMaskedLMOutput,
40
+ FlaxMultipleChoiceModelOutput,
41
+ FlaxQuestionAnsweringModelOutput,
42
+ FlaxSequenceClassifierOutput,
43
+ FlaxTokenClassifierOutput,
44
+ )
45
+ from transformers.modeling_flax_utils import (
46
+ ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring,
47
+ overwrite_call_docstring
48
+ )
49
+ from transformers.utils import (
50
+ add_start_docstrings, add_start_docstrings_to_model_forward, logging
51
+ )
52
+ from transformers import AutoTokenizer
53
+
54
+ from ml_collections import ConfigDict
55
+ from ml_collections.config_dict import config_dict
56
+ from mlxu import function_args_to_config, load_pickle
57
+
58
+ from EasyLM.jax_utils import with_sharding_constraint, get_jax_mesh
59
+
60
+
61
+ """
62
+ The follow code is taken from
63
+ transformers/src/transformers/models/roberta/configuration_roberta.py
64
+ and modified to work with EasyLM.
65
+ """
66
+
67
+
68
+ ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
69
+ "roberta-base": "https://huggingface.co/roberta-base/resolve/main/config.json",
70
+ "roberta-large": "https://huggingface.co/roberta-large/resolve/main/config.json",
71
+ "roberta-large-mnli": "https://huggingface.co/roberta-large-mnli/resolve/main/config.json",
72
+ "distilroberta-base": "https://huggingface.co/distilroberta-base/resolve/main/config.json",
73
+ "roberta-base-openai-detector": "https://huggingface.co/roberta-base-openai-detector/resolve/main/config.json",
74
+ "roberta-large-openai-detector": "https://huggingface.co/roberta-large-openai-detector/resolve/main/config.json",
75
+ }
76
+
77
+
78
+ class RobertaConfig(PretrainedConfig):
79
+ r"""
80
+ This is the configuration class to store the configuration of a [`RobertaModel`] or a [`TFRobertaModel`]. It is
81
+ used to instantiate a RoBERTa model according to the specified arguments, defining the model architecture.
82
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the RoBERTa
83
+ [roberta-base](https://huggingface.co/roberta-base) architecture.
84
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
85
+ documentation from [`PretrainedConfig`] for more information.
86
+ Args:
87
+ vocab_size (`int`, *optional*, defaults to 30522):
88
+ Vocabulary size of the RoBERTa model. Defines the number of different tokens that can be represented by the
89
+ `inputs_ids` passed when calling [`RobertaModel`] or [`TFRobertaModel`].
90
+ hidden_size (`int`, *optional*, defaults to 768):
91
+ Dimensionality of the encoder layers and the pooler layer.
92
+ num_hidden_layers (`int`, *optional*, defaults to 12):
93
+ Number of hidden layers in the Transformer encoder.
94
+ num_attention_heads (`int`, *optional*, defaults to 12):
95
+ Number of attention heads for each attention layer in the Transformer encoder.
96
+ intermediate_size (`int`, *optional*, defaults to 3072):
97
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
98
+ hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
99
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
100
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
101
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
102
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
103
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
104
+ The dropout ratio for the attention probabilities.
105
+ max_position_embeddings (`int`, *optional*, defaults to 512):
106
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
107
+ just in case (e.g., 512 or 1024 or 2048).
108
+ type_vocab_size (`int`, *optional*, defaults to 2):
109
+ The vocabulary size of the `token_type_ids` passed when calling [`RobertaModel`] or [`TFRobertaModel`].
110
+ initializer_range (`float`, *optional*, defaults to 0.02):
111
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
112
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
113
+ The epsilon used by the layer normalization layers.
114
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
115
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
116
+ positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
117
+ [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
118
+ For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
119
+ with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
120
+ use_cache (`bool`, *optional*, defaults to `True`):
121
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
122
+ relevant if `config.is_decoder=True`.
123
+ classifier_dropout (`float`, *optional*):
124
+ The dropout ratio for the classification head.
125
+ Examples:
126
+ ```python
127
+ >>> from transformers import RobertaConfig, RobertaModel
128
+ >>> # Initializing a RoBERTa configuration
129
+ >>> configuration = RobertaConfig()
130
+ >>> # Initializing a model (with random weights) from the configuration
131
+ >>> model = RobertaModel(configuration)
132
+ >>> # Accessing the model configuration
133
+ >>> configuration = model.config
134
+ ```"""
135
+ model_type = "roberta"
136
+
137
+ def __init__(
138
+ self,
139
+ vocab_size=50265,
140
+ hidden_size=768,
141
+ num_hidden_layers=12,
142
+ num_attention_heads=12,
143
+ intermediate_size=3072,
144
+ hidden_act="gelu",
145
+ hidden_dropout_prob=0.1,
146
+ attention_probs_dropout_prob=0.1,
147
+ max_position_embeddings=514,
148
+ type_vocab_size=1,
149
+ initializer_range=0.02,
150
+ layer_norm_eps=1e-5,
151
+ pad_token_id=1,
152
+ bos_token_id=0,
153
+ eos_token_id=2,
154
+ position_embedding_type="absolute",
155
+ use_cache=True,
156
+ classifier_dropout=None,
157
+ **kwargs
158
+ ):
159
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
160
+
161
+ self.vocab_size = vocab_size
162
+ self.hidden_size = hidden_size
163
+ self.num_hidden_layers = num_hidden_layers
164
+ self.num_attention_heads = num_attention_heads
165
+ self.hidden_act = hidden_act
166
+ self.intermediate_size = intermediate_size
167
+ self.hidden_dropout_prob = hidden_dropout_prob
168
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
169
+ self.max_position_embeddings = max_position_embeddings
170
+ self.type_vocab_size = type_vocab_size
171
+ self.initializer_range = initializer_range
172
+ self.layer_norm_eps = layer_norm_eps
173
+ self.position_embedding_type = position_embedding_type
174
+ self.use_cache = use_cache
175
+ self.classifier_dropout = classifier_dropout
176
+
177
+ @classmethod
178
+ def get_default_config(cls, updates=None):
179
+ none_arg_types = dict(
180
+ classifier_dropout=float,
181
+ )
182
+ config = function_args_to_config(cls.__init__, none_arg_types=none_arg_types)
183
+ config.tie_word_embeddings = True
184
+
185
+ if updates is not None:
186
+ config.update(ConfigDict(updates).copy_and_resolve_references())
187
+
188
+ return config
189
+
190
+ @staticmethod
191
+ def get_jax_mesh(axis_dims):
192
+ return get_jax_mesh(axis_dims, ('dp', 'fsdp', 'mp'))
193
+
194
+ @staticmethod
195
+ def get_partition_rules():
196
+ """ Parition rules for Roberta model. """
197
+ return (
198
+ ('embeddings/(position_embeddings|token_type_embeddings)/embedding', PartitionSpec()),
199
+ ('embeddings/word_embeddings/embedding', PartitionSpec()),
200
+ ('attention/self/(key|query|value)/kernel', PartitionSpec('fsdp', 'mp')),
201
+ ('attention/self/(key|query|value)/bias', PartitionSpec()),
202
+ ('attention/output/dense/kernel', PartitionSpec('mp', 'fsdp')),
203
+ ('attention/output/dense/bias', PartitionSpec()),
204
+ ('(LayerNorm|layer_norm)/(bias|scale)', PartitionSpec()),
205
+ ('intermediate/dense/kernel', PartitionSpec('fsdp', 'mp')),
206
+ ('intermediate/dense/bias', PartitionSpec('mp')),
207
+ ('output/dense/kernel', PartitionSpec('mp', 'fsdp')),
208
+ ('output/dense/bias', PartitionSpec()),
209
+ ('lm_head/dense/kernel', PartitionSpec()),
210
+ ('lm_head/dense/bias', PartitionSpec()),
211
+ ('lm_head/decoder/kernel', PartitionSpec('fsdp', 'mp')),
212
+ ('lm_head/decoder/bias', PartitionSpec('mp')),
213
+ ('.*', PartitionSpec()),
214
+ )
215
+
216
+ @staticmethod
217
+ def get_weight_decay_exclusions():
218
+ return ('bias', 'LayerNorm/scale', 'layer_norm/scale')
219
+
220
+ @staticmethod
221
+ def rng_keys():
222
+ return ('params', 'dropout')
223
+
224
+ @staticmethod
225
+ def get_tokenizer_config(updates=None):
226
+ config = ConfigDict()
227
+ config.name = 'roberta-base'
228
+
229
+ if updates is not None:
230
+ config.update(ConfigDict(updates).copy_and_resolve_references())
231
+
232
+ return config
233
+
234
+ @classmethod
235
+ def get_tokenizer(cls, config):
236
+ config = cls.get_tokenizer_config(config)
237
+ return AutoTokenizer.from_pretrained(
238
+ config.name,
239
+ )
240
+
241
+ @staticmethod
242
+ def load_pretrained(name):
243
+ with jax.default_device(jax.devices("cpu")[0]):
244
+ params = FlaxRobertaForMaskedLM.from_pretrained(name, _do_init=False)[1]
245
+ params = freeze({'params': params})
246
+ return params
247
+
248
+ @classmethod
249
+ def load_config(cls, path):
250
+ load_type, load_path = path.split('::', 1)
251
+ if load_type == 'pickle':
252
+ return cls.from_dict(load_pickle(load_path)['roberta_config'])
253
+ elif load_type == 'huggingface':
254
+ return cls.from_pretrained(load_path)
255
+ else:
256
+ raise ValueError(f'Unsupported load config type: {load_type}')
257
+
258
+
259
+ """
260
+ The follow code is taken from
261
+ transformers/src/transformers/models/roberta/modeling_flax_roberta.py
262
+ and modified to work with EasyLM.
263
+ """
264
+
265
+
266
+ logger = logging.get_logger(__name__)
267
+
268
+ _CHECKPOINT_FOR_DOC = "roberta-base"
269
+ _CONFIG_FOR_DOC = "RobertaConfig"
270
+
271
+ remat = nn_partitioning.remat
272
+
273
+
274
+ def create_position_ids_from_input_ids(input_ids, padding_idx):
275
+ """
276
+ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
277
+ are ignored. This is modified from fairseq's `utils.make_positions`.
278
+ Args:
279
+ input_ids: jnp.ndarray
280
+ padding_idx: int
281
+ Returns: jnp.ndarray
282
+ """
283
+ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
284
+ mask = (input_ids != padding_idx).astype("i4")
285
+
286
+ if mask.ndim > 2:
287
+ mask = mask.reshape((-1, mask.shape[-1]))
288
+ incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask
289
+ incremental_indices = incremental_indices.reshape(input_ids.shape)
290
+ else:
291
+ incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask
292
+
293
+ return incremental_indices.astype("i4") + padding_idx
294
+
295
+
296
+ ROBERTA_START_DOCSTRING = r"""
297
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
298
+ library implements for all its model (such as downloading, saving and converting weights from PyTorch models)
299
+ This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
300
+ subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
301
+ general usage and behavior.
302
+ Finally, this model supports inherent JAX features such as:
303
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
304
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
305
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
306
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
307
+ Parameters:
308
+ config ([`RobertaConfig`]): Model configuration class with all the parameters of the
309
+ model. Initializing with a config file does not load the weights associated with the model, only the
310
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
311
+ """
312
+
313
+ ROBERTA_INPUTS_DOCSTRING = r"""
314
+ Args:
315
+ input_ids (`numpy.ndarray` of shape `({0})`):
316
+ Indices of input sequence tokens in the vocabulary.
317
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
318
+ [`PreTrainedTokenizer.__call__`] for details.
319
+ [What are input IDs?](../glossary#input-ids)
320
+ attention_mask (`numpy.ndarray` of shape `({0})`, *optional*):
321
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
322
+ - 1 for tokens that are **not masked**,
323
+ - 0 for tokens that are **masked**.
324
+ [What are attention masks?](../glossary#attention-mask)
325
+ token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*):
326
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
327
+ 1]`:
328
+ - 0 corresponds to a *sentence A* token,
329
+ - 1 corresponds to a *sentence B* token.
330
+ [What are token type IDs?](../glossary#token-type-ids)
331
+ position_ids (`numpy.ndarray` of shape `({0})`, *optional*):
332
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
333
+ config.max_position_embeddings - 1]`.
334
+ head_mask (`numpy.ndarray` of shape `({0})`, `optional):
335
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
336
+ - 1 indicates the head is **not masked**,
337
+ - 0 indicates the head is **masked**.
338
+ return_dict (`bool`, *optional*):
339
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
340
+ """
341
+
342
+
343
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings with Bert->Roberta
344
+ class FlaxRobertaEmbeddings(nn.Module):
345
+ """Construct the embeddings from word, position and token_type embeddings."""
346
+
347
+ config: RobertaConfig
348
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
349
+
350
+ def setup(self):
351
+ self.word_embeddings = nn.Embed(
352
+ self.config.vocab_size,
353
+ self.config.hidden_size,
354
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
355
+ dtype=self.dtype,
356
+ )
357
+ self.position_embeddings = nn.Embed(
358
+ self.config.max_position_embeddings,
359
+ self.config.hidden_size,
360
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
361
+ dtype=self.dtype,
362
+ )
363
+ self.token_type_embeddings = nn.Embed(
364
+ self.config.type_vocab_size,
365
+ self.config.hidden_size,
366
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
367
+ dtype=self.dtype,
368
+ )
369
+ self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
370
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
371
+
372
+ def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
373
+ # Embed
374
+ inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
375
+ position_embeds = self.position_embeddings(position_ids.astype("i4"))
376
+ token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))
377
+
378
+ # Sum all embeddings
379
+ hidden_states = inputs_embeds + token_type_embeddings + position_embeds
380
+
381
+ # Layer Norm
382
+ hidden_states = self.LayerNorm(hidden_states)
383
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
384
+ return hidden_states
385
+
386
+
387
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->Roberta
388
+ class FlaxRobertaSelfAttention(nn.Module):
389
+ config: RobertaConfig
390
+ causal: bool = False
391
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
392
+
393
+ def setup(self):
394
+ self.head_dim = self.config.hidden_size // self.config.num_attention_heads
395
+ if self.config.hidden_size % self.config.num_attention_heads != 0:
396
+ raise ValueError(
397
+ "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` "
398
+ " : {self.config.num_attention_heads}"
399
+ )
400
+
401
+ self.query = nn.Dense(
402
+ self.config.hidden_size,
403
+ dtype=self.dtype,
404
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
405
+ )
406
+ self.key = nn.Dense(
407
+ self.config.hidden_size,
408
+ dtype=self.dtype,
409
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
410
+ )
411
+ self.value = nn.Dense(
412
+ self.config.hidden_size,
413
+ dtype=self.dtype,
414
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
415
+ )
416
+
417
+ if self.causal:
418
+ self.causal_mask = make_causal_mask(
419
+ jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
420
+ )
421
+
422
+ def _split_heads(self, hidden_states):
423
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim))
424
+
425
+ def _merge_heads(self, hidden_states):
426
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,))
427
+
428
+ @nn.compact
429
+ # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache
430
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
431
+ """
432
+ This function takes projected key, value states from a single input token and concatenates the states to cached
433
+ states from previous steps. This function is slighly adapted from the official Flax repository:
434
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
435
+ """
436
+ # detect if we're initializing by absence of existing cache data.
437
+ is_initialized = self.has_variable("cache", "cached_key")
438
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
439
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
440
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
441
+
442
+ if is_initialized:
443
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
444
+ # update key, value caches with our new 1d spatial slices
445
+ cur_index = cache_index.value
446
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
447
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
448
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
449
+ cached_key.value = key
450
+ cached_value.value = value
451
+ num_updated_cache_vectors = query.shape[1]
452
+ cache_index.value = cache_index.value + num_updated_cache_vectors
453
+ # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
454
+ pad_mask = jnp.broadcast_to(
455
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
456
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
457
+ )
458
+ attention_mask = combine_masks(pad_mask, attention_mask)
459
+ return key, value, attention_mask
460
+
461
+ def __call__(
462
+ self,
463
+ hidden_states,
464
+ attention_mask,
465
+ layer_head_mask,
466
+ key_value_states: Optional[jnp.array] = None,
467
+ init_cache: bool = False,
468
+ deterministic=True,
469
+ output_attentions: bool = False,
470
+ ):
471
+ # if key_value_states are provided this layer is used as a cross-attention layer
472
+ # for the decoder
473
+ is_cross_attention = key_value_states is not None
474
+ batch_size = hidden_states.shape[0]
475
+
476
+ # get query proj
477
+ query_states = self.query(hidden_states)
478
+ # get key, value proj
479
+ if is_cross_attention:
480
+ # cross_attentions
481
+ key_states = self.key(key_value_states)
482
+ value_states = self.value(key_value_states)
483
+ else:
484
+ # self_attention
485
+ key_states = self.key(hidden_states)
486
+ value_states = self.value(hidden_states)
487
+
488
+ query_states = self._split_heads(query_states)
489
+ key_states = self._split_heads(key_states)
490
+ value_states = self._split_heads(value_states)
491
+
492
+ # handle cache prepare causal attention mask
493
+ if self.causal:
494
+ query_length, key_length = query_states.shape[1], key_states.shape[1]
495
+ if self.has_variable("cache", "cached_key"):
496
+ mask_shift = self.variables["cache"]["cache_index"]
497
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
498
+ causal_mask = lax.dynamic_slice(
499
+ self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
500
+ )
501
+ else:
502
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
503
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
504
+
505
+ # combine masks if needed
506
+ if attention_mask is not None and self.causal:
507
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
508
+ attention_mask = combine_masks(attention_mask, causal_mask)
509
+ elif self.causal:
510
+ attention_mask = causal_mask
511
+ elif attention_mask is not None:
512
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
513
+
514
+ # During fast autoregressive decoding, we feed one position at a time,
515
+ # and cache the keys and values step by step.
516
+ if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
517
+ key_states, value_states, attention_mask = self._concatenate_to_cache(
518
+ key_states, value_states, query_states, attention_mask
519
+ )
520
+
521
+ # Convert the boolean attention mask to an attention bias.
522
+ if attention_mask is not None:
523
+ # attention mask in the form of attention bias
524
+ attention_bias = lax.select(
525
+ attention_mask > 0,
526
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
527
+ jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
528
+ )
529
+ else:
530
+ attention_bias = None
531
+
532
+ dropout_rng = None
533
+ if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
534
+ dropout_rng = self.make_rng("dropout")
535
+
536
+ attn_weights = dot_product_attention_weights(
537
+ query_states,
538
+ key_states,
539
+ bias=attention_bias,
540
+ dropout_rng=dropout_rng,
541
+ dropout_rate=self.config.attention_probs_dropout_prob,
542
+ broadcast_dropout=True,
543
+ deterministic=deterministic,
544
+ dtype=self.dtype,
545
+ precision=None,
546
+ )
547
+
548
+ # Mask heads if we want to
549
+ if layer_head_mask is not None:
550
+ attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask)
551
+
552
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
553
+ attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
554
+
555
+ outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
556
+ return outputs
557
+
558
+
559
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->Roberta
560
+ class FlaxRobertaSelfOutput(nn.Module):
561
+ config: RobertaConfig
562
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
563
+
564
+ def setup(self):
565
+ self.dense = nn.Dense(
566
+ self.config.hidden_size,
567
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
568
+ dtype=self.dtype,
569
+ )
570
+ self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
571
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
572
+
573
+ def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
574
+ hidden_states = self.dense(hidden_states)
575
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
576
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
577
+ return hidden_states
578
+
579
+
580
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Roberta
581
+ class FlaxRobertaAttention(nn.Module):
582
+ config: RobertaConfig
583
+ causal: bool = False
584
+ dtype: jnp.dtype = jnp.float32
585
+
586
+ def setup(self):
587
+ self.self = FlaxRobertaSelfAttention(self.config, causal=self.causal, dtype=self.dtype)
588
+ self.output = FlaxRobertaSelfOutput(self.config, dtype=self.dtype)
589
+
590
+ def __call__(
591
+ self,
592
+ hidden_states,
593
+ attention_mask,
594
+ layer_head_mask,
595
+ key_value_states=None,
596
+ init_cache=False,
597
+ deterministic=True,
598
+ output_attentions: bool = False,
599
+ ):
600
+ # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
601
+ # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
602
+ # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
603
+ attn_outputs = self.self(
604
+ hidden_states,
605
+ attention_mask,
606
+ layer_head_mask=layer_head_mask,
607
+ key_value_states=key_value_states,
608
+ init_cache=init_cache,
609
+ deterministic=deterministic,
610
+ output_attentions=output_attentions,
611
+ )
612
+ attn_output = attn_outputs[0]
613
+ hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
614
+
615
+ outputs = (hidden_states,)
616
+
617
+ if output_attentions:
618
+ outputs += (attn_outputs[1],)
619
+
620
+ return outputs
621
+
622
+
623
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta
624
+ class FlaxRobertaIntermediate(nn.Module):
625
+ config: RobertaConfig
626
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
627
+
628
+ def setup(self):
629
+ self.dense = nn.Dense(
630
+ self.config.intermediate_size,
631
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
632
+ dtype=self.dtype,
633
+ )
634
+ self.activation = ACT2FN[self.config.hidden_act]
635
+
636
+ def __call__(self, hidden_states):
637
+ hidden_states = self.dense(hidden_states)
638
+ hidden_states = self.activation(hidden_states)
639
+ return hidden_states
640
+
641
+
642
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->Roberta
643
+ class FlaxRobertaOutput(nn.Module):
644
+ config: RobertaConfig
645
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
646
+
647
+ def setup(self):
648
+ self.dense = nn.Dense(
649
+ self.config.hidden_size,
650
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
651
+ dtype=self.dtype,
652
+ )
653
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
654
+ self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
655
+
656
+ def __call__(self, hidden_states, attention_output, deterministic: bool = True):
657
+ hidden_states = self.dense(hidden_states)
658
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
659
+ hidden_states = self.LayerNorm(hidden_states + attention_output)
660
+ return hidden_states
661
+
662
+
663
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer with Bert->Roberta
664
+ class FlaxRobertaLayer(nn.Module):
665
+ config: RobertaConfig
666
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
667
+
668
+ def setup(self):
669
+ self.attention = FlaxRobertaAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype)
670
+ self.intermediate = FlaxRobertaIntermediate(self.config, dtype=self.dtype)
671
+ self.output = FlaxRobertaOutput(self.config, dtype=self.dtype)
672
+ if self.config.add_cross_attention:
673
+ self.crossattention = FlaxRobertaAttention(self.config, causal=False, dtype=self.dtype)
674
+
675
+ def __call__(
676
+ self,
677
+ hidden_states,
678
+ attention_mask,
679
+ layer_head_mask,
680
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
681
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
682
+ init_cache: bool = False,
683
+ deterministic: bool = True,
684
+ output_attentions: bool = False,
685
+ ):
686
+ # Self Attention
687
+ attention_outputs = self.attention(
688
+ hidden_states,
689
+ attention_mask,
690
+ layer_head_mask=layer_head_mask,
691
+ init_cache=init_cache,
692
+ deterministic=deterministic,
693
+ output_attentions=output_attentions,
694
+ )
695
+ attention_output = attention_outputs[0]
696
+
697
+ # Cross-Attention Block
698
+ if encoder_hidden_states is not None:
699
+ cross_attention_outputs = self.crossattention(
700
+ attention_output,
701
+ attention_mask=encoder_attention_mask,
702
+ layer_head_mask=layer_head_mask,
703
+ key_value_states=encoder_hidden_states,
704
+ deterministic=deterministic,
705
+ output_attentions=output_attentions,
706
+ )
707
+ attention_output = cross_attention_outputs[0]
708
+
709
+ hidden_states = self.intermediate(attention_output)
710
+ hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
711
+
712
+ outputs = (hidden_states,)
713
+
714
+ if output_attentions:
715
+ outputs += (attention_outputs[1],)
716
+ if encoder_hidden_states is not None:
717
+ outputs += (cross_attention_outputs[1],)
718
+ return outputs
719
+
720
+
721
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->Roberta
722
+ class FlaxRobertaLayerCollection(nn.Module):
723
+ config: RobertaConfig
724
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
725
+ gradient_checkpointing: bool = False
726
+
727
+ def setup(self):
728
+ if self.gradient_checkpointing:
729
+ FlaxRobertaCheckpointLayer = remat(FlaxRobertaLayer, static_argnums=(5, 6, 7))
730
+ self.layers = [
731
+ FlaxRobertaCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
732
+ for i in range(self.config.num_hidden_layers)
733
+ ]
734
+ else:
735
+ self.layers = [
736
+ FlaxRobertaLayer(self.config, name=str(i), dtype=self.dtype)
737
+ for i in range(self.config.num_hidden_layers)
738
+ ]
739
+
740
+ def __call__(
741
+ self,
742
+ hidden_states,
743
+ attention_mask,
744
+ head_mask,
745
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
746
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
747
+ init_cache: bool = False,
748
+ deterministic: bool = True,
749
+ output_attentions: bool = False,
750
+ output_hidden_states: bool = False,
751
+ return_dict: bool = True,
752
+ ):
753
+ all_attentions = () if output_attentions else None
754
+ all_hidden_states = () if output_hidden_states else None
755
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
756
+
757
+ # Check if head_mask has a correct number of layers specified if desired
758
+ if head_mask is not None:
759
+ if head_mask.shape[0] != (len(self.layers)):
760
+ raise ValueError(
761
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for "
762
+ f" {head_mask.shape[0]}."
763
+ )
764
+
765
+ for i, layer in enumerate(self.layers):
766
+ if output_hidden_states:
767
+ all_hidden_states += (hidden_states,)
768
+
769
+ layer_outputs = layer(
770
+ hidden_states,
771
+ attention_mask,
772
+ head_mask[i] if head_mask is not None else None,
773
+ encoder_hidden_states,
774
+ encoder_attention_mask,
775
+ init_cache,
776
+ deterministic,
777
+ output_attentions,
778
+ )
779
+
780
+ hidden_states = layer_outputs[0]
781
+
782
+ if output_attentions:
783
+ all_attentions += (layer_outputs[1],)
784
+
785
+ if encoder_hidden_states is not None:
786
+ all_cross_attentions += (layer_outputs[2],)
787
+
788
+ if output_hidden_states:
789
+ all_hidden_states += (hidden_states,)
790
+
791
+ outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions)
792
+
793
+ if not return_dict:
794
+ return tuple(v for v in outputs if v is not None)
795
+
796
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
797
+ last_hidden_state=hidden_states,
798
+ hidden_states=all_hidden_states,
799
+ attentions=all_attentions,
800
+ cross_attentions=all_cross_attentions,
801
+ )
802
+
803
+
804
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->Roberta
805
+ class FlaxRobertaEncoder(nn.Module):
806
+ config: RobertaConfig
807
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
808
+ gradient_checkpointing: bool = False
809
+
810
+ def setup(self):
811
+ self.layer = FlaxRobertaLayerCollection(
812
+ self.config,
813
+ dtype=self.dtype,
814
+ gradient_checkpointing=self.gradient_checkpointing,
815
+ )
816
+
817
+ def __call__(
818
+ self,
819
+ hidden_states,
820
+ attention_mask,
821
+ head_mask,
822
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
823
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
824
+ init_cache: bool = False,
825
+ deterministic: bool = True,
826
+ output_attentions: bool = False,
827
+ output_hidden_states: bool = False,
828
+ return_dict: bool = True,
829
+ ):
830
+ return self.layer(
831
+ hidden_states,
832
+ attention_mask,
833
+ head_mask=head_mask,
834
+ encoder_hidden_states=encoder_hidden_states,
835
+ encoder_attention_mask=encoder_attention_mask,
836
+ init_cache=init_cache,
837
+ deterministic=deterministic,
838
+ output_attentions=output_attentions,
839
+ output_hidden_states=output_hidden_states,
840
+ return_dict=return_dict,
841
+ )
842
+
843
+
844
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->Roberta
845
+ class FlaxRobertaPooler(nn.Module):
846
+ config: RobertaConfig
847
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
848
+
849
+ def setup(self):
850
+ self.dense = nn.Dense(
851
+ self.config.hidden_size,
852
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
853
+ dtype=self.dtype,
854
+ )
855
+
856
+ def __call__(self, hidden_states):
857
+ cls_hidden_state = hidden_states[:, 0]
858
+ cls_hidden_state = self.dense(cls_hidden_state)
859
+ return nn.tanh(cls_hidden_state)
860
+
861
+
862
+ class FlaxRobertaLMHead(nn.Module):
863
+ config: RobertaConfig
864
+ dtype: jnp.dtype = jnp.float32
865
+ bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
866
+
867
+ def setup(self):
868
+ self.dense = nn.Dense(
869
+ self.config.hidden_size,
870
+ dtype=self.dtype,
871
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
872
+ )
873
+ self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
874
+ self.decoder = nn.Dense(
875
+ self.config.vocab_size,
876
+ dtype=self.dtype,
877
+ use_bias=False,
878
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
879
+ )
880
+ self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))
881
+
882
+ def __call__(self, hidden_states, shared_embedding=None):
883
+ hidden_states = self.dense(hidden_states)
884
+ hidden_states = ACT2FN["gelu"](hidden_states)
885
+ hidden_states = self.layer_norm(hidden_states)
886
+
887
+ if shared_embedding is not None:
888
+ hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
889
+ else:
890
+ hidden_states = self.decoder(hidden_states)
891
+
892
+ bias = jnp.asarray(self.bias, self.dtype)
893
+ hidden_states += bias
894
+ return hidden_states
895
+
896
+
897
+ class FlaxRobertaClassificationHead(nn.Module):
898
+ config: RobertaConfig
899
+ dtype: jnp.dtype = jnp.float32
900
+
901
+ def setup(self):
902
+ self.dense = nn.Dense(
903
+ self.config.hidden_size,
904
+ dtype=self.dtype,
905
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
906
+ )
907
+ classifier_dropout = (
908
+ self.config.classifier_dropout
909
+ if self.config.classifier_dropout is not None
910
+ else self.config.hidden_dropout_prob
911
+ )
912
+ self.dropout = nn.Dropout(rate=classifier_dropout)
913
+ self.out_proj = nn.Dense(
914
+ self.config.num_labels,
915
+ dtype=self.dtype,
916
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
917
+ )
918
+
919
+ def __call__(self, hidden_states, deterministic=True):
920
+ hidden_states = hidden_states[:, 0, :] # take <s> token (equiv. to [CLS])
921
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
922
+ hidden_states = self.dense(hidden_states)
923
+ hidden_states = nn.tanh(hidden_states)
924
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
925
+ hidden_states = self.out_proj(hidden_states)
926
+ return hidden_states
927
+
928
+
929
+ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
930
+ """
931
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
932
+ models.
933
+ """
934
+
935
+ config_class = RobertaConfig
936
+ base_model_prefix = "roberta"
937
+
938
+ module_class: nn.Module = None
939
+
940
+ def __init__(
941
+ self,
942
+ config: RobertaConfig,
943
+ input_shape: Tuple = (1, 1),
944
+ seed: int = 0,
945
+ dtype: jnp.dtype = jnp.float32,
946
+ _do_init: bool = True,
947
+ gradient_checkpointing: bool = False,
948
+ **kwargs,
949
+ ):
950
+ module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
951
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
952
+
953
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing
954
+ def enable_gradient_checkpointing(self):
955
+ self._module = self.module_class(
956
+ config=self.config,
957
+ dtype=self.dtype,
958
+ gradient_checkpointing=True,
959
+ )
960
+
961
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
962
+ # init input tensors
963
+ input_ids = jnp.zeros(input_shape, dtype="i4")
964
+ token_type_ids = jnp.ones_like(input_ids)
965
+ position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
966
+ attention_mask = jnp.ones_like(input_ids)
967
+ head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))
968
+
969
+ params_rng, dropout_rng = jax.random.split(rng)
970
+ rngs = {"params": params_rng, "dropout": dropout_rng}
971
+
972
+ if self.config.add_cross_attention:
973
+ encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
974
+ encoder_attention_mask = attention_mask
975
+ module_init_outputs = self.module.init(
976
+ rngs,
977
+ input_ids,
978
+ attention_mask,
979
+ token_type_ids,
980
+ position_ids,
981
+ head_mask,
982
+ encoder_hidden_states,
983
+ encoder_attention_mask,
984
+ return_dict=False,
985
+ )
986
+ else:
987
+ module_init_outputs = self.module.init(
988
+ rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
989
+ )
990
+
991
+ random_params = module_init_outputs["params"]
992
+
993
+ if params is not None:
994
+ random_params = flatten_dict(unfreeze(random_params))
995
+ params = flatten_dict(unfreeze(params))
996
+ for missing_key in self._missing_keys:
997
+ params[missing_key] = random_params[missing_key]
998
+ self._missing_keys = set()
999
+ return freeze(unflatten_dict(params))
1000
+ else:
1001
+ return random_params
1002
+
1003
+ # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache
1004
+ def init_cache(self, batch_size, max_length):
1005
+ r"""
1006
+ Args:
1007
+ batch_size (`int`):
1008
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
1009
+ max_length (`int`):
1010
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
1011
+ cache.
1012
+ """
1013
+ # init input variables to retrieve cache
1014
+ input_ids = jnp.ones((batch_size, max_length), dtype="i4")
1015
+ attention_mask = jnp.ones_like(input_ids, dtype="i4")
1016
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
1017
+
1018
+ init_variables = self.module.init(
1019
+ jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
1020
+ )
1021
+ return unfreeze(init_variables["cache"])
1022
+
1023
+ @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1024
+ def __call__(
1025
+ self,
1026
+ input_ids,
1027
+ attention_mask=None,
1028
+ token_type_ids=None,
1029
+ position_ids=None,
1030
+ head_mask=None,
1031
+ encoder_hidden_states=None,
1032
+ encoder_attention_mask=None,
1033
+ params: dict = None,
1034
+ dropout_rng: jax.random.PRNGKey = None,
1035
+ train: bool = False,
1036
+ output_attentions: Optional[bool] = None,
1037
+ output_hidden_states: Optional[bool] = None,
1038
+ return_dict: Optional[bool] = None,
1039
+ past_key_values: dict = None,
1040
+ ):
1041
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1042
+ output_hidden_states = (
1043
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1044
+ )
1045
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
1046
+
1047
+ # init input tensors if not passed
1048
+ if token_type_ids is None:
1049
+ token_type_ids = jnp.zeros_like(input_ids)
1050
+
1051
+ if position_ids is None:
1052
+ position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
1053
+
1054
+ if attention_mask is None:
1055
+ attention_mask = jnp.ones_like(input_ids)
1056
+
1057
+ if head_mask is None:
1058
+ head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))
1059
+
1060
+ # Handle any PRNG if needed
1061
+ rngs = {}
1062
+ if dropout_rng is not None:
1063
+ rngs["dropout"] = dropout_rng
1064
+
1065
+ inputs = {"params": params or self.params}
1066
+
1067
+ if self.config.add_cross_attention:
1068
+ # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed
1069
+ # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be
1070
+ # changed by FlaxRobertaAttention module
1071
+ if past_key_values:
1072
+ inputs["cache"] = past_key_values
1073
+ mutable = ["cache"]
1074
+ else:
1075
+ mutable = False
1076
+
1077
+ outputs = self.module.apply(
1078
+ inputs,
1079
+ jnp.array(input_ids, dtype="i4"),
1080
+ jnp.array(attention_mask, dtype="i4"),
1081
+ token_type_ids=jnp.array(token_type_ids, dtype="i4"),
1082
+ position_ids=jnp.array(position_ids, dtype="i4"),
1083
+ head_mask=jnp.array(head_mask, dtype="i4"),
1084
+ encoder_hidden_states=encoder_hidden_states,
1085
+ encoder_attention_mask=encoder_attention_mask,
1086
+ deterministic=not train,
1087
+ output_attentions=output_attentions,
1088
+ output_hidden_states=output_hidden_states,
1089
+ return_dict=return_dict,
1090
+ rngs=rngs,
1091
+ mutable=mutable,
1092
+ )
1093
+
1094
+ # add updated cache to model output
1095
+ if past_key_values is not None and return_dict:
1096
+ outputs, past_key_values = outputs
1097
+ outputs["past_key_values"] = unfreeze(past_key_values["cache"])
1098
+ return outputs
1099
+ elif past_key_values is not None and not return_dict:
1100
+ outputs, past_key_values = outputs
1101
+ outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
1102
+
1103
+ else:
1104
+ outputs = self.module.apply(
1105
+ inputs,
1106
+ jnp.array(input_ids, dtype="i4"),
1107
+ jnp.array(attention_mask, dtype="i4"),
1108
+ token_type_ids=jnp.array(token_type_ids, dtype="i4"),
1109
+ position_ids=jnp.array(position_ids, dtype="i4"),
1110
+ head_mask=jnp.array(head_mask, dtype="i4"),
1111
+ deterministic=not train,
1112
+ output_attentions=output_attentions,
1113
+ output_hidden_states=output_hidden_states,
1114
+ return_dict=return_dict,
1115
+ rngs=rngs,
1116
+ )
1117
+
1118
+ return outputs
1119
+
1120
+
1121
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta
1122
+ class FlaxRobertaModule(nn.Module):
1123
+ config: RobertaConfig
1124
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
1125
+ add_pooling_layer: bool = True
1126
+ gradient_checkpointing: bool = False
1127
+
1128
+ def setup(self):
1129
+ self.embeddings = FlaxRobertaEmbeddings(self.config, dtype=self.dtype)
1130
+ self.encoder = FlaxRobertaEncoder(
1131
+ self.config,
1132
+ dtype=self.dtype,
1133
+ gradient_checkpointing=self.gradient_checkpointing,
1134
+ )
1135
+ self.pooler = FlaxRobertaPooler(self.config, dtype=self.dtype)
1136
+
1137
+ def __call__(
1138
+ self,
1139
+ input_ids,
1140
+ attention_mask,
1141
+ token_type_ids: Optional[jnp.ndarray] = None,
1142
+ position_ids: Optional[jnp.ndarray] = None,
1143
+ head_mask: Optional[jnp.ndarray] = None,
1144
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
1145
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
1146
+ init_cache: bool = False,
1147
+ deterministic: bool = True,
1148
+ output_attentions: bool = False,
1149
+ output_hidden_states: bool = False,
1150
+ return_dict: bool = True,
1151
+ ):
1152
+ # make sure `token_type_ids` is correctly initialized when not passed
1153
+ if token_type_ids is None:
1154
+ token_type_ids = jnp.zeros_like(input_ids)
1155
+
1156
+ # make sure `position_ids` is correctly initialized when not passed
1157
+ if position_ids is None:
1158
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
1159
+
1160
+ hidden_states = self.embeddings(
1161
+ input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
1162
+ )
1163
+ outputs = self.encoder(
1164
+ hidden_states,
1165
+ attention_mask,
1166
+ head_mask=head_mask,
1167
+ deterministic=deterministic,
1168
+ encoder_hidden_states=encoder_hidden_states,
1169
+ encoder_attention_mask=encoder_attention_mask,
1170
+ init_cache=init_cache,
1171
+ output_attentions=output_attentions,
1172
+ output_hidden_states=output_hidden_states,
1173
+ return_dict=return_dict,
1174
+ )
1175
+ hidden_states = outputs[0]
1176
+ pooled = self.pooler(hidden_states) if self.add_pooling_layer else None
1177
+
1178
+ if not return_dict:
1179
+ # if pooled is None, don't return it
1180
+ if pooled is None:
1181
+ return (hidden_states,) + outputs[1:]
1182
+ return (hidden_states, pooled) + outputs[1:]
1183
+
1184
+ return FlaxBaseModelOutputWithPoolingAndCrossAttentions(
1185
+ last_hidden_state=hidden_states,
1186
+ pooler_output=pooled,
1187
+ hidden_states=outputs.hidden_states,
1188
+ attentions=outputs.attentions,
1189
+ cross_attentions=outputs.cross_attentions,
1190
+ )
1191
+
1192
+
1193
+ @add_start_docstrings(
1194
+ "The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
1195
+ ROBERTA_START_DOCSTRING,
1196
+ )
1197
+ class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
1198
+ module_class = FlaxRobertaModule
1199
+
1200
+
1201
+ append_call_sample_docstring(FlaxRobertaModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC)
1202
+
1203
+
1204
+ class FlaxRobertaForMaskedLMModule(nn.Module):
1205
+ config: RobertaConfig
1206
+ dtype: jnp.dtype = jnp.float32
1207
+ gradient_checkpointing: bool = False
1208
+
1209
+ def setup(self):
1210
+ self.roberta = FlaxRobertaModule(
1211
+ config=self.config,
1212
+ add_pooling_layer=False,
1213
+ dtype=self.dtype,
1214
+ gradient_checkpointing=self.gradient_checkpointing,
1215
+ )
1216
+ self.lm_head = FlaxRobertaLMHead(config=self.config, dtype=self.dtype)
1217
+
1218
+ def __call__(
1219
+ self,
1220
+ input_ids,
1221
+ attention_mask,
1222
+ token_type_ids,
1223
+ position_ids,
1224
+ head_mask,
1225
+ deterministic: bool = True,
1226
+ output_attentions: bool = False,
1227
+ output_hidden_states: bool = False,
1228
+ return_dict: bool = True,
1229
+ ):
1230
+ # Model
1231
+ outputs = self.roberta(
1232
+ input_ids,
1233
+ attention_mask,
1234
+ token_type_ids,
1235
+ position_ids,
1236
+ head_mask,
1237
+ deterministic=deterministic,
1238
+ output_attentions=output_attentions,
1239
+ output_hidden_states=output_hidden_states,
1240
+ return_dict=return_dict,
1241
+ )
1242
+
1243
+ hidden_states = outputs[0]
1244
+ if self.config.tie_word_embeddings:
1245
+ shared_embedding = self.roberta.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
1246
+ else:
1247
+ shared_embedding = None
1248
+
1249
+ # Compute the prediction scores
1250
+ logits = self.lm_head(hidden_states, shared_embedding=shared_embedding)
1251
+
1252
+ if not return_dict:
1253
+ return (logits,) + outputs[1:]
1254
+
1255
+ return FlaxMaskedLMOutput(
1256
+ logits=logits,
1257
+ hidden_states=outputs.hidden_states,
1258
+ attentions=outputs.attentions,
1259
+ )
1260
+
1261
+
1262
+ @add_start_docstrings("""RoBERTa Model with a `language modeling` head on top.""", ROBERTA_START_DOCSTRING)
1263
+ class FlaxRobertaForMaskedLM(FlaxRobertaPreTrainedModel):
1264
+ module_class = FlaxRobertaForMaskedLMModule
1265
+
1266
+
1267
+ append_call_sample_docstring(
1268
+ FlaxRobertaForMaskedLM,
1269
+ _CHECKPOINT_FOR_DOC,
1270
+ FlaxBaseModelOutputWithPooling,
1271
+ _CONFIG_FOR_DOC,
1272
+ mask="<mask>",
1273
+ )
1274
+
1275
+
1276
+ class FlaxRobertaForSequenceClassificationModule(nn.Module):
1277
+ config: RobertaConfig
1278
+ dtype: jnp.dtype = jnp.float32
1279
+ gradient_checkpointing: bool = False
1280
+
1281
+ def setup(self):
1282
+ self.roberta = FlaxRobertaModule(
1283
+ config=self.config,
1284
+ dtype=self.dtype,
1285
+ add_pooling_layer=False,
1286
+ gradient_checkpointing=self.gradient_checkpointing,
1287
+ )
1288
+ self.classifier = FlaxRobertaClassificationHead(config=self.config, dtype=self.dtype)
1289
+
1290
+ def __call__(
1291
+ self,
1292
+ input_ids,
1293
+ attention_mask,
1294
+ token_type_ids,
1295
+ position_ids,
1296
+ head_mask,
1297
+ deterministic: bool = True,
1298
+ output_attentions: bool = False,
1299
+ output_hidden_states: bool = False,
1300
+ return_dict: bool = True,
1301
+ ):
1302
+ # Model
1303
+ outputs = self.roberta(
1304
+ input_ids,
1305
+ attention_mask,
1306
+ token_type_ids,
1307
+ position_ids,
1308
+ head_mask,
1309
+ deterministic=deterministic,
1310
+ output_attentions=output_attentions,
1311
+ output_hidden_states=output_hidden_states,
1312
+ return_dict=return_dict,
1313
+ )
1314
+
1315
+ sequence_output = outputs[0]
1316
+ logits = self.classifier(sequence_output, deterministic=deterministic)
1317
+
1318
+ if not return_dict:
1319
+ return (logits,) + outputs[1:]
1320
+
1321
+ return FlaxSequenceClassifierOutput(
1322
+ logits=logits,
1323
+ hidden_states=outputs.hidden_states,
1324
+ attentions=outputs.attentions,
1325
+ )
1326
+
1327
+
1328
+ @add_start_docstrings(
1329
+ """
1330
+ Roberta Model transformer with a sequence classification/regression head on top (a linear layer on top of the
1331
+ pooled output) e.g. for GLUE tasks.
1332
+ """,
1333
+ ROBERTA_START_DOCSTRING,
1334
+ )
1335
+ class FlaxRobertaForSequenceClassification(FlaxRobertaPreTrainedModel):
1336
+ module_class = FlaxRobertaForSequenceClassificationModule
1337
+
1338
+
1339
+ append_call_sample_docstring(
1340
+ FlaxRobertaForSequenceClassification,
1341
+ _CHECKPOINT_FOR_DOC,
1342
+ FlaxSequenceClassifierOutput,
1343
+ _CONFIG_FOR_DOC,
1344
+ )
1345
+
1346
+
1347
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForMultipleChoiceModule with Bert->Roberta, with self.bert->self.roberta
1348
+ class FlaxRobertaForMultipleChoiceModule(nn.Module):
1349
+ config: RobertaConfig
1350
+ dtype: jnp.dtype = jnp.float32
1351
+ gradient_checkpointing: bool = False
1352
+
1353
+ def setup(self):
1354
+ self.roberta = FlaxRobertaModule(
1355
+ config=self.config,
1356
+ dtype=self.dtype,
1357
+ gradient_checkpointing=self.gradient_checkpointing,
1358
+ )
1359
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
1360
+ self.classifier = nn.Dense(1, dtype=self.dtype)
1361
+
1362
+ def __call__(
1363
+ self,
1364
+ input_ids,
1365
+ attention_mask,
1366
+ token_type_ids,
1367
+ position_ids,
1368
+ head_mask,
1369
+ deterministic: bool = True,
1370
+ output_attentions: bool = False,
1371
+ output_hidden_states: bool = False,
1372
+ return_dict: bool = True,
1373
+ ):
1374
+ num_choices = input_ids.shape[1]
1375
+ input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None
1376
+ attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None
1377
+ token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None
1378
+ position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None
1379
+
1380
+ # Model
1381
+ outputs = self.roberta(
1382
+ input_ids,
1383
+ attention_mask,
1384
+ token_type_ids,
1385
+ position_ids,
1386
+ head_mask,
1387
+ deterministic=deterministic,
1388
+ output_attentions=output_attentions,
1389
+ output_hidden_states=output_hidden_states,
1390
+ return_dict=return_dict,
1391
+ )
1392
+
1393
+ pooled_output = outputs[1]
1394
+ pooled_output = self.dropout(pooled_output, deterministic=deterministic)
1395
+ logits = self.classifier(pooled_output)
1396
+
1397
+ reshaped_logits = logits.reshape(-1, num_choices)
1398
+
1399
+ if not return_dict:
1400
+ return (reshaped_logits,) + outputs[2:]
1401
+
1402
+ return FlaxMultipleChoiceModelOutput(
1403
+ logits=reshaped_logits,
1404
+ hidden_states=outputs.hidden_states,
1405
+ attentions=outputs.attentions,
1406
+ )
1407
+
1408
+
1409
+ @add_start_docstrings(
1410
+ """
1411
+ Roberta Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
1412
+ softmax) e.g. for RocStories/SWAG tasks.
1413
+ """,
1414
+ ROBERTA_START_DOCSTRING,
1415
+ )
1416
+ class FlaxRobertaForMultipleChoice(FlaxRobertaPreTrainedModel):
1417
+ module_class = FlaxRobertaForMultipleChoiceModule
1418
+
1419
+
1420
+ overwrite_call_docstring(
1421
+ FlaxRobertaForMultipleChoice, ROBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
1422
+ )
1423
+ append_call_sample_docstring(
1424
+ FlaxRobertaForMultipleChoice,
1425
+ _CHECKPOINT_FOR_DOC,
1426
+ FlaxMultipleChoiceModelOutput,
1427
+ _CONFIG_FOR_DOC,
1428
+ )
1429
+
1430
+
1431
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForTokenClassificationModule with Bert->Roberta, with self.bert->self.roberta
1432
+ class FlaxRobertaForTokenClassificationModule(nn.Module):
1433
+ config: RobertaConfig
1434
+ dtype: jnp.dtype = jnp.float32
1435
+ gradient_checkpointing: bool = False
1436
+
1437
+ def setup(self):
1438
+ self.roberta = FlaxRobertaModule(
1439
+ config=self.config,
1440
+ dtype=self.dtype,
1441
+ add_pooling_layer=False,
1442
+ gradient_checkpointing=self.gradient_checkpointing,
1443
+ )
1444
+ classifier_dropout = (
1445
+ self.config.classifier_dropout
1446
+ if self.config.classifier_dropout is not None
1447
+ else self.config.hidden_dropout_prob
1448
+ )
1449
+ self.dropout = nn.Dropout(rate=classifier_dropout)
1450
+ self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
1451
+
1452
+ def __call__(
1453
+ self,
1454
+ input_ids,
1455
+ attention_mask,
1456
+ token_type_ids,
1457
+ position_ids,
1458
+ head_mask,
1459
+ deterministic: bool = True,
1460
+ output_attentions: bool = False,
1461
+ output_hidden_states: bool = False,
1462
+ return_dict: bool = True,
1463
+ ):
1464
+ # Model
1465
+ outputs = self.roberta(
1466
+ input_ids,
1467
+ attention_mask,
1468
+ token_type_ids,
1469
+ position_ids,
1470
+ head_mask,
1471
+ deterministic=deterministic,
1472
+ output_attentions=output_attentions,
1473
+ output_hidden_states=output_hidden_states,
1474
+ return_dict=return_dict,
1475
+ )
1476
+
1477
+ hidden_states = outputs[0]
1478
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
1479
+ logits = self.classifier(hidden_states)
1480
+
1481
+ if not return_dict:
1482
+ return (logits,) + outputs[1:]
1483
+
1484
+ return FlaxTokenClassifierOutput(
1485
+ logits=logits,
1486
+ hidden_states=outputs.hidden_states,
1487
+ attentions=outputs.attentions,
1488
+ )
1489
+
1490
+
1491
+ @add_start_docstrings(
1492
+ """
1493
+ Roberta Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1494
+ Named-Entity-Recognition (NER) tasks.
1495
+ """,
1496
+ ROBERTA_START_DOCSTRING,
1497
+ )
1498
+ class FlaxRobertaForTokenClassification(FlaxRobertaPreTrainedModel):
1499
+ module_class = FlaxRobertaForTokenClassificationModule
1500
+
1501
+
1502
+ append_call_sample_docstring(
1503
+ FlaxRobertaForTokenClassification,
1504
+ _CHECKPOINT_FOR_DOC,
1505
+ FlaxTokenClassifierOutput,
1506
+ _CONFIG_FOR_DOC,
1507
+ )
1508
+
1509
+
1510
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForQuestionAnsweringModule with Bert->Roberta, with self.bert->self.roberta
1511
+ class FlaxRobertaForQuestionAnsweringModule(nn.Module):
1512
+ config: RobertaConfig
1513
+ dtype: jnp.dtype = jnp.float32
1514
+ gradient_checkpointing: bool = False
1515
+
1516
+ def setup(self):
1517
+ self.roberta = FlaxRobertaModule(
1518
+ config=self.config,
1519
+ dtype=self.dtype,
1520
+ add_pooling_layer=False,
1521
+ gradient_checkpointing=self.gradient_checkpointing,
1522
+ )
1523
+ self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
1524
+
1525
+ def __call__(
1526
+ self,
1527
+ input_ids,
1528
+ attention_mask,
1529
+ token_type_ids,
1530
+ position_ids,
1531
+ head_mask,
1532
+ deterministic: bool = True,
1533
+ output_attentions: bool = False,
1534
+ output_hidden_states: bool = False,
1535
+ return_dict: bool = True,
1536
+ ):
1537
+ # Model
1538
+ outputs = self.roberta(
1539
+ input_ids,
1540
+ attention_mask,
1541
+ token_type_ids,
1542
+ position_ids,
1543
+ head_mask,
1544
+ deterministic=deterministic,
1545
+ output_attentions=output_attentions,
1546
+ output_hidden_states=output_hidden_states,
1547
+ return_dict=return_dict,
1548
+ )
1549
+
1550
+ hidden_states = outputs[0]
1551
+
1552
+ logits = self.qa_outputs(hidden_states)
1553
+ start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
1554
+ start_logits = start_logits.squeeze(-1)
1555
+ end_logits = end_logits.squeeze(-1)
1556
+
1557
+ if not return_dict:
1558
+ return (start_logits, end_logits) + outputs[1:]
1559
+
1560
+ return FlaxQuestionAnsweringModelOutput(
1561
+ start_logits=start_logits,
1562
+ end_logits=end_logits,
1563
+ hidden_states=outputs.hidden_states,
1564
+ attentions=outputs.attentions,
1565
+ )
1566
+
1567
+
1568
+ @add_start_docstrings(
1569
+ """
1570
+ Roberta Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
1571
+ layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1572
+ """,
1573
+ ROBERTA_START_DOCSTRING,
1574
+ )
1575
+ class FlaxRobertaForQuestionAnswering(FlaxRobertaPreTrainedModel):
1576
+ module_class = FlaxRobertaForQuestionAnsweringModule
1577
+
1578
+
1579
+ append_call_sample_docstring(
1580
+ FlaxRobertaForQuestionAnswering,
1581
+ _CHECKPOINT_FOR_DOC,
1582
+ FlaxQuestionAnsweringModelOutput,
1583
+ _CONFIG_FOR_DOC,
1584
+ )
1585
+
1586
+
1587
+ class FlaxRobertaForCausalLMModule(nn.Module):
1588
+ config: RobertaConfig
1589
+ dtype: jnp.dtype = jnp.float32
1590
+ gradient_checkpointing: bool = False
1591
+
1592
+ def setup(self):
1593
+ self.roberta = FlaxRobertaModule(
1594
+ config=self.config,
1595
+ add_pooling_layer=False,
1596
+ dtype=self.dtype,
1597
+ gradient_checkpointing=self.gradient_checkpointing,
1598
+ )
1599
+ self.lm_head = FlaxRobertaLMHead(config=self.config, dtype=self.dtype)
1600
+
1601
+ def __call__(
1602
+ self,
1603
+ input_ids,
1604
+ attention_mask,
1605
+ position_ids,
1606
+ token_type_ids: Optional[jnp.ndarray] = None,
1607
+ head_mask: Optional[jnp.ndarray] = None,
1608
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
1609
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
1610
+ init_cache: bool = False,
1611
+ deterministic: bool = True,
1612
+ output_attentions: bool = False,
1613
+ output_hidden_states: bool = False,
1614
+ return_dict: bool = True,
1615
+ ):
1616
+ # Model
1617
+ outputs = self.roberta(
1618
+ input_ids,
1619
+ attention_mask,
1620
+ token_type_ids,
1621
+ position_ids,
1622
+ head_mask,
1623
+ encoder_hidden_states=encoder_hidden_states,
1624
+ encoder_attention_mask=encoder_attention_mask,
1625
+ init_cache=init_cache,
1626
+ deterministic=deterministic,
1627
+ output_attentions=output_attentions,
1628
+ output_hidden_states=output_hidden_states,
1629
+ return_dict=return_dict,
1630
+ )
1631
+
1632
+ hidden_states = outputs[0]
1633
+ if self.config.tie_word_embeddings:
1634
+ shared_embedding = self.roberta.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
1635
+ else:
1636
+ shared_embedding = None
1637
+
1638
+ # Compute the prediction scores
1639
+ logits = self.lm_head(hidden_states, shared_embedding=shared_embedding)
1640
+
1641
+ if not return_dict:
1642
+ return (logits,) + outputs[1:]
1643
+
1644
+ return FlaxCausalLMOutputWithCrossAttentions(
1645
+ logits=logits,
1646
+ hidden_states=outputs.hidden_states,
1647
+ attentions=outputs.attentions,
1648
+ cross_attentions=outputs.cross_attentions,
1649
+ )
1650
+
1651
+
1652
+ @add_start_docstrings(
1653
+ """
1654
+ Roberta Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for
1655
+ autoregressive tasks.
1656
+ """,
1657
+ ROBERTA_START_DOCSTRING,
1658
+ )
1659
+ class FlaxRobertaForCausalLM(FlaxRobertaPreTrainedModel):
1660
+ module_class = FlaxRobertaForCausalLMModule
1661
+
1662
+ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
1663
+ # initializing the cache
1664
+ batch_size, seq_length = input_ids.shape
1665
+
1666
+ past_key_values = self.init_cache(batch_size, max_length)
1667
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
1668
+ # But since the decoder uses a causal mask, those positions are masked anyway.
1669
+ # Thus, we can create a single static attention_mask here, which is more efficient for compilation
1670
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
1671
+ if attention_mask is not None:
1672
+ position_ids = attention_mask.cumsum(axis=-1) - 1
1673
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
1674
+ else:
1675
+ position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
1676
+
1677
+ return {
1678
+ "past_key_values": past_key_values,
1679
+ "attention_mask": extended_attention_mask,
1680
+ "position_ids": position_ids,
1681
+ }
1682
+
1683
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
1684
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
1685
+ model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
1686
+ return model_kwargs
1687
+
1688
+
1689
+ append_call_sample_docstring(
1690
+ FlaxRobertaForCausalLM,
1691
+ _CHECKPOINT_FOR_DOC,
1692
+ FlaxCausalLMOutputWithCrossAttentions,
1693
+ _CONFIG_FOR_DOC,
1694
+ )
EasyLM/models/roberta/roberta_train.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import pprint
3
+ from functools import partial
4
+ import re
5
+
6
+ from tqdm import tqdm, trange
7
+ import numpy as np
8
+ import mlxu
9
+
10
+ import jax
11
+ import jax.numpy as jnp
12
+ from jax.experimental.pjit import pjit, with_sharding_constraint
13
+ from jax.sharding import PartitionSpec as PS
14
+ from flax.training.train_state import TrainState
15
+
16
+ from EasyLM.data import DatasetFactory
17
+ from EasyLM.checkpoint import StreamingCheckpointer
18
+ from EasyLM.optimizers import OptimizerFactory
19
+ from EasyLM.jax_utils import (
20
+ JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, get_float_dtype_by_name,
21
+ cross_entropy_loss_and_accuracy, named_tree_map, global_norm,
22
+ set_random_seed, average_metrics, get_weight_decay_mask,
23
+ make_shard_and_gather_fns, tree_apply
24
+ )
25
+ from EasyLM.models.roberta.roberta_model import (
26
+ RobertaConfig, FlaxRobertaForMaskedLMModule
27
+ )
28
+
29
+
30
+ FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
31
+ seed=42,
32
+ mesh_dim='-1,1,1',
33
+ dtype='fp32',
34
+ mask_token_probability=0.15,
35
+ total_steps=10000,
36
+ load_roberta_config='',
37
+ update_roberta_config='',
38
+ load_checkpoint='',
39
+ load_dataset_state='',
40
+ log_freq=50,
41
+ save_model_freq=0,
42
+ save_milestone_freq=0,
43
+ eval_steps=0,
44
+ tokenizer=RobertaConfig.get_tokenizer_config(),
45
+ train_dataset=DatasetFactory.get_default_config(),
46
+ eval_dataset=DatasetFactory.get_default_config(),
47
+ optimizer=OptimizerFactory.get_default_config(),
48
+ checkpointer=StreamingCheckpointer.get_default_config(),
49
+ roberta=RobertaConfig.get_default_config(),
50
+ logger=mlxu.WandBLogger.get_default_config(),
51
+ log_all_worker=False,
52
+ jax_distributed=JaxDistributedConfig.get_default_config(),
53
+ )
54
+
55
+
56
+ def main(argv):
57
+ JaxDistributedConfig.initialize(FLAGS.jax_distributed)
58
+ variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
59
+ flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
60
+ logger = mlxu.WandBLogger(
61
+ config=FLAGS.logger,
62
+ variant=variant,
63
+ enable=FLAGS.log_all_worker or (jax.process_index() == 0),
64
+ )
65
+ set_random_seed(FLAGS.seed)
66
+
67
+ tokenizer = RobertaConfig.get_tokenizer(FLAGS.tokenizer)
68
+ dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer)
69
+ if FLAGS.load_dataset_state != '':
70
+ dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state))
71
+
72
+ if FLAGS.eval_steps > 0:
73
+ eval_dataset = DatasetFactory.load_dataset(
74
+ FLAGS.eval_dataset, dataset.tokenizer
75
+ )
76
+ eval_iterator = iter(eval_dataset)
77
+
78
+ seq_length = dataset.seq_length
79
+
80
+ if FLAGS.load_roberta_config != '':
81
+ roberta_config = RobertaConfig.load_config(FLAGS.load_roberta_config)
82
+ else:
83
+ roberta_config = RobertaConfig(**FLAGS.roberta)
84
+
85
+ if FLAGS.update_roberta_config != '':
86
+ roberta_config.update(dict(eval(FLAGS.update_roberta_config)))
87
+
88
+ roberta_config.update(dict(
89
+ bos_token_id=dataset.tokenizer.bos_token_id,
90
+ eos_token_id=dataset.tokenizer.eos_token_id,
91
+ pad_token_id=dataset.tokenizer.pad_token_id,
92
+ vocab_size=dataset.vocab_size,
93
+ ))
94
+
95
+ model = FlaxRobertaForMaskedLMModule(
96
+ roberta_config, dtype=get_float_dtype_by_name(FLAGS.dtype)
97
+ )
98
+
99
+ optimizer, optimizer_info = OptimizerFactory.get_optimizer(
100
+ FLAGS.optimizer,
101
+ get_weight_decay_mask(RobertaConfig.get_weight_decay_exclusions()),
102
+ )
103
+
104
+ def create_trainstate_from_params(params):
105
+ return TrainState.create(params=params, tx=optimizer, apply_fn=None)
106
+
107
+ def init_fn(rng):
108
+ rng_generator = JaxRNG(rng)
109
+ params = model.init(
110
+ input_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
111
+ position_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
112
+ attention_mask=jnp.ones((4, seq_length), dtype=jnp.int32),
113
+ token_type_ids=None,
114
+ head_mask=None,
115
+ rngs=rng_generator(roberta_config.rng_keys()),
116
+ )
117
+ return TrainState.create(params=params, tx=optimizer, apply_fn=None)
118
+
119
+ def train_step(train_state, rng, batch):
120
+ rng_generator = JaxRNG(rng)
121
+ tokens = with_sharding_constraint(batch['target_tokens'], PS(('dp', 'fsdp')))
122
+ def loss_and_accuracy(params):
123
+ altered_tokens = jax.random.uniform(
124
+ rng_generator(), shape=tokens.shape
125
+ ) < FLAGS.mask_token_probability
126
+ random_uniform = jax.random.uniform(rng_generator(), shape=tokens.shape)
127
+ altered_by_mask = altered_tokens & (random_uniform < 0.8)
128
+ altered_by_random = altered_tokens & (random_uniform >= 0.8) & (random_uniform < 0.9)
129
+ inputs = jnp.where(altered_by_mask, dataset.tokenizer.mask_token_id, tokens)
130
+ random_tokens = jax.random.randint(
131
+ rng_generator(), shape=tokens.shape, minval=0, maxval=dataset.vocab_size
132
+ )
133
+ inputs = jnp.where(altered_by_random, random_tokens, inputs)
134
+ logits = model.apply(
135
+ params, inputs,
136
+ attention_mask=jnp.ones_like(inputs),
137
+ token_type_ids=None,
138
+ position_ids=None,
139
+ head_mask=None,
140
+ deterministic=False,
141
+ rngs=rng_generator(roberta_config.rng_keys()),
142
+ ).logits
143
+ return cross_entropy_loss_and_accuracy(logits, tokens, valid=altered_tokens)
144
+ grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
145
+ (loss, accuracy), grads = grad_fn(train_state.params)
146
+ train_state = train_state.apply_gradients(grads=grads)
147
+ metrics = dict(
148
+ loss=loss,
149
+ accuracy=accuracy,
150
+ learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
151
+ gradient_norm=global_norm(grads),
152
+ param_norm=global_norm(train_state.params),
153
+ )
154
+ return train_state, rng_generator(), metrics
155
+
156
+ def eval_step(train_state, rng, batch):
157
+ rng_generator = JaxRNG(rng)
158
+ tokens = with_sharding_constraint(batch['target_tokens'], PS(('dp', 'fsdp')))
159
+ altered_tokens = jax.random.uniform(
160
+ rng_generator(), shape=tokens.shape
161
+ ) < FLAGS.mask_token_probability
162
+ random_uniform = jax.random.uniform(rng_generator(), shape=tokens.shape)
163
+ altered_by_mask = altered_tokens & (random_uniform < 0.8)
164
+ altered_by_random = altered_tokens & (random_uniform >= 0.8) & (random_uniform < 0.9)
165
+ inputs = jnp.where(altered_by_mask, dataset.tokenizer.mask_token_id, tokens)
166
+ random_tokens = jax.random.randint(
167
+ rng_generator(), shape=tokens.shape, minval=0, maxval=dataset.vocab_size
168
+ )
169
+ inputs = jnp.where(altered_by_random, random_tokens, inputs)
170
+ logits = model.apply(
171
+ train_state.params, inputs,
172
+ attention_mask=jnp.ones_like(inputs),
173
+ token_type_ids=None,
174
+ position_ids=None,
175
+ head_mask=None,
176
+ deterministic=False,
177
+ rngs=rng_generator(roberta_config.rng_keys()),
178
+ ).logits
179
+ loss, accuracy = cross_entropy_loss_and_accuracy(logits, tokens, valid=altered_tokens)
180
+ metrics = dict(
181
+ eval_loss=loss,
182
+ eval_accuracy=accuracy,
183
+ )
184
+ return rng_generator(), metrics
185
+
186
+ train_state_shapes = jax.eval_shape(init_fn, next_rng())
187
+ train_state_partition = match_partition_rules(
188
+ RobertaConfig.get_partition_rules(), train_state_shapes
189
+ )
190
+
191
+ shard_fns, gather_fns = make_shard_and_gather_fns(
192
+ train_state_partition, train_state_shapes
193
+ )
194
+ checkpointer = StreamingCheckpointer(
195
+ FLAGS.checkpointer, logger.output_dir,
196
+ enable=jax.process_index() == 0
197
+ )
198
+
199
+ sharded_init_fn = pjit(
200
+ init_fn,
201
+ in_shardings=PS(),
202
+ out_shardings=train_state_partition
203
+ )
204
+
205
+ sharded_create_trainstate_from_params = pjit(
206
+ create_trainstate_from_params,
207
+ in_shardings=(train_state_partition.params, ),
208
+ out_shardings=train_state_partition,
209
+ donate_argnums=(0, ),
210
+ )
211
+
212
+ sharded_train_step = pjit(
213
+ train_step,
214
+ in_shardings=(train_state_partition, PS(), PS()),
215
+ out_shardings=(train_state_partition, PS(), PS()),
216
+ donate_argnums=(0, 1),
217
+ )
218
+
219
+ sharded_eval_step = pjit(
220
+ eval_step,
221
+ in_shardings=(train_state_partition, PS(), PS()),
222
+ out_shardings=(PS(), PS()),
223
+ donate_argnums=(1,),
224
+ )
225
+
226
+ def save_checkpoint(train_state, milestone=False):
227
+ step = int(jax.device_get(train_state.step))
228
+ metadata = dict(
229
+ step=step,
230
+ variant=variant,
231
+ flags=flags_config_dict,
232
+ roberta_config=roberta_config.to_dict(),
233
+ )
234
+ checkpointer.save_all(
235
+ train_state=train_state,
236
+ gather_fns=gather_fns,
237
+ metadata=metadata,
238
+ dataset=dataset.get_state_dict(),
239
+ milestone=milestone,
240
+ )
241
+
242
+ mesh = RobertaConfig.get_jax_mesh(FLAGS.mesh_dim)
243
+ with mesh:
244
+ train_state, restored_params = None, None
245
+ if FLAGS.load_checkpoint != '':
246
+ load_type, load_path = FLAGS.load_checkpoint.split('::', 1)
247
+ if load_type == 'huggingface':
248
+ restored_params = tree_apply(
249
+ shard_fns.params, roberta_config.load_pretrained(load_path)
250
+ )
251
+ train_state = None
252
+ else:
253
+ train_state, restored_params = checkpointer.load_trainstate_checkpoint(
254
+ FLAGS.load_checkpoint, train_state_shapes, shard_fns
255
+ )
256
+
257
+ if train_state is None and restored_params is None:
258
+ # Initialize from scratch
259
+ train_state = sharded_init_fn(next_rng())
260
+ elif train_state is None and restored_params is not None:
261
+ # Restore from params but initialize train_state
262
+ train_state = sharded_create_trainstate_from_params(restored_params)
263
+ del restored_params
264
+
265
+ start_step = int(jax.device_get(train_state.step))
266
+
267
+ if FLAGS.save_model_freq > 0:
268
+ save_checkpoint(train_state)
269
+
270
+ sharded_rng = next_rng()
271
+
272
+ step_counter = trange(start_step, FLAGS.total_steps, ncols=0)
273
+
274
+ for step, (batch, dataset_metrics) in zip(step_counter, dataset):
275
+ train_state, sharded_rng, metrics = sharded_train_step(
276
+ train_state, sharded_rng, batch
277
+ )
278
+
279
+ if step % FLAGS.log_freq == 0:
280
+ if FLAGS.eval_steps > 0:
281
+ eval_metric_list = []
282
+ for _ in range(FLAGS.eval_steps):
283
+ eval_batch, _ = next(eval_iterator)
284
+ sharded_rng, eval_metrics = sharded_eval_step(
285
+ train_state, sharded_rng, eval_batch
286
+ )
287
+ eval_metric_list.append(eval_metrics)
288
+ metrics.update(average_metrics(eval_metric_list))
289
+
290
+ log_metrics = {"step": step}
291
+ log_metrics.update(metrics)
292
+ log_metrics.update(dataset_metrics)
293
+ log_metrics = jax.device_get(log_metrics)
294
+ logger.log(log_metrics)
295
+ tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")
296
+
297
+ if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0:
298
+ save_checkpoint(train_state, milestone=True)
299
+ elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0:
300
+ save_checkpoint(train_state)
301
+
302
+ if FLAGS.save_model_freq > 0:
303
+ save_checkpoint(train_state)
304
+
305
+
306
+ if __name__ == "__main__":
307
+ mlxu.run(main)
EasyLM/optimizers.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ from typing import Any, Mapping, Text, Tuple, Union, NamedTuple
4
+ from functools import partial
5
+ import re
6
+ import dataclasses
7
+ import random
8
+
9
+ from ml_collections.config_dict import config_dict
10
+ from ml_collections import ConfigDict
11
+ import jax
12
+ import jax.numpy as jnp
13
+ import numpy as np
14
+ from absl import logging
15
+ import optax
16
+
17
+ from EasyLM.jax_utils import float_to_dtype
18
+
19
+
20
+ class OptimizerFactory(object):
21
+ """ Configurable optax optimizer factory. """
22
+
23
+ def __init__(self):
24
+ raise NotImplementedError
25
+
26
+ @staticmethod
27
+ def get_default_config(updates=None):
28
+ config = ConfigDict()
29
+ config.accumulate_gradient_steps = 1
30
+ config.type = 'adamw'
31
+ config.palm_optimizer = PalmOptimizerFactory.get_default_config()
32
+ config.adamw_optimizer = AdamWOptimizerFactory.get_default_config()
33
+ config.lion_optimizer = LionOptimizerFactory.get_default_config()
34
+
35
+ if updates is not None:
36
+ config.update(ConfigDict(updates).copy_and_resolve_references())
37
+ return config
38
+
39
+ @classmethod
40
+ def get_optimizer(cls, config, weight_decay_mask=None):
41
+ config = cls.get_default_config(config)
42
+ if config.type == 'palm':
43
+ optimizer, optimizer_info = PalmOptimizerFactory.get_optimizer(
44
+ config.palm_optimizer, weight_decay_mask
45
+ )
46
+ elif config.type == 'adamw':
47
+ optimizer, optimizer_info = AdamWOptimizerFactory.get_optimizer(
48
+ config.adamw_optimizer, weight_decay_mask
49
+ )
50
+ elif config.type == 'lion':
51
+ optimizer, optimizer_info = LionOptimizerFactory.get_optimizer(
52
+ config.lion_optimizer, weight_decay_mask
53
+ )
54
+ else:
55
+ raise ValueError(f'Unknown optimizer type: {config.type}')
56
+
57
+ if config.accumulate_gradient_steps > 1:
58
+ optimizer = optax.MultiSteps(
59
+ optimizer, config.accumulate_gradient_steps
60
+ )
61
+
62
+ return optimizer, optimizer_info
63
+
64
+
65
+ class PalmOptimizerFactory(object):
66
+ """ PaLM optimizer factory. This optimizer implements the optimizer
67
+ described in the PaLM paper: https://arxiv.org/abs/2204.02311
68
+ """
69
+
70
+ def __init__(self):
71
+ raise NotImplementedError
72
+
73
+ @staticmethod
74
+ def get_default_config(updates=None):
75
+ config = ConfigDict()
76
+ config.lr = 0.01
77
+ config.lr_warmup_steps = 10000
78
+ config.b1 = 0.9
79
+ config.b2 = 0.99
80
+ config.clip_gradient = 1.0
81
+ config.weight_decay = 1e-4
82
+ config.bf16_momentum = False
83
+
84
+ if updates is not None:
85
+ config.update(ConfigDict(updates).copy_and_resolve_references())
86
+ return config
87
+
88
+ @classmethod
89
+ def get_optimizer(cls, config, weight_decay_mask=None):
90
+ config = cls.get_default_config(config)
91
+
92
+ def learning_rate_schedule(step):
93
+ multiplier = config.lr / 0.01
94
+ return multiplier / jnp.sqrt(jnp.maximum(step, config.lr_warmup_steps))
95
+
96
+ def weight_decay_schedule(step):
97
+ multiplier = config.weight_decay / 1e-4
98
+ return -multiplier * jnp.square(learning_rate_schedule(step))
99
+
100
+ optimizer_info = dict(
101
+ learning_rate_schedule=learning_rate_schedule,
102
+ weight_decay_schedule=weight_decay_schedule,
103
+ )
104
+
105
+ optimizer = optax.chain(
106
+ optax.clip_by_global_norm(config.clip_gradient),
107
+ optax.adafactor(
108
+ learning_rate=learning_rate_schedule,
109
+ multiply_by_parameter_scale=True,
110
+ momentum=config.b1,
111
+ decay_rate=config.b2,
112
+ factored=False,
113
+ clipping_threshold=None,
114
+ dtype_momentum=jnp.bfloat16 if config.bf16_momentum else jnp.float32,
115
+ ),
116
+ optax_add_scheduled_weight_decay(
117
+ weight_decay_schedule, weight_decay_mask
118
+ )
119
+ )
120
+ return optimizer, optimizer_info
121
+
122
+
123
+ class AdamWOptimizerFactory(object):
124
+ """ AdamW optimizer with cosine schedule. """
125
+
126
+ def __init__(self):
127
+ raise NotImplementedError
128
+
129
+ @staticmethod
130
+ def get_default_config(updates=None):
131
+ config = ConfigDict()
132
+ config.init_lr = 0.0
133
+ config.end_lr = 0.001
134
+ config.lr = 0.01
135
+ config.lr_warmup_steps = 2000
136
+ config.lr_decay_steps = 500000
137
+ config.b1 = 0.9
138
+ config.b2 = 0.95
139
+ config.clip_gradient = 1.0
140
+ config.weight_decay = 1e-4
141
+ config.bf16_momentum = False
142
+ config.multiply_by_parameter_scale = False
143
+
144
+ if updates is not None:
145
+ config.update(ConfigDict(updates).copy_and_resolve_references())
146
+ return config
147
+
148
+ @classmethod
149
+ def get_optimizer(cls, config, weight_decay_mask=None):
150
+ config = cls.get_default_config(config)
151
+
152
+ learning_rate_schedule = optax.warmup_cosine_decay_schedule(
153
+ init_value=config.init_lr,
154
+ peak_value=config.lr,
155
+ warmup_steps=config.lr_warmup_steps,
156
+ decay_steps=config.lr_decay_steps,
157
+ end_value=config.end_lr,
158
+ )
159
+
160
+ optimizer_info = dict(
161
+ learning_rate_schedule=learning_rate_schedule,
162
+ )
163
+
164
+ if config.multiply_by_parameter_scale:
165
+ optimizer = optax.chain(
166
+ optax.clip_by_global_norm(config.clip_gradient),
167
+ optax.adafactor(
168
+ learning_rate=learning_rate_schedule,
169
+ multiply_by_parameter_scale=True,
170
+ momentum=config.b1,
171
+ decay_rate=config.b2,
172
+ factored=False,
173
+ clipping_threshold=None,
174
+ dtype_momentum=jnp.bfloat16 if config.bf16_momentum else jnp.float32,
175
+ ),
176
+ optax_add_scheduled_weight_decay(
177
+ lambda step: -learning_rate_schedule(step) * config.weight_decay,
178
+ weight_decay_mask
179
+ )
180
+ )
181
+ else:
182
+ optimizer = optax.chain(
183
+ optax.clip_by_global_norm(config.clip_gradient),
184
+ optax.adamw(
185
+ learning_rate=learning_rate_schedule,
186
+ weight_decay=config.weight_decay,
187
+ b1=config.b1,
188
+ b2=config.b2,
189
+ mask=weight_decay_mask,
190
+ mu_dtype=jnp.bfloat16 if config.bf16_momentum else jnp.float32,
191
+ ),
192
+ )
193
+
194
+ return optimizer, optimizer_info
195
+
196
+ class LionOptimizerFactory(object):
197
+ """ Lion optimizer with cosine schedule. """
198
+
199
+ def __init__(self):
200
+ raise NotImplementedError
201
+
202
+ @staticmethod
203
+ def get_default_config(updates=None):
204
+ config = ConfigDict()
205
+ config.init_lr = 0.0
206
+ config.end_lr = 0.0001
207
+ config.lr = 0.001
208
+ config.lr_warmup_steps = 2000
209
+ config.lr_decay_steps = 500000
210
+ config.b1 = 0.9
211
+ config.b2 = 0.98
212
+ config.clip_gradient = 1.0
213
+ config.weight_decay = 1e-3
214
+ config.bf16_momentum = False
215
+ config.lr_schedule_type = "warmup_cosine_decay_schedule"
216
+ config.lr_decay_rate = 0.98
217
+
218
+ if updates is not None:
219
+ config.update(ConfigDict(updates).copy_and_resolve_references())
220
+ return config
221
+
222
+ @classmethod
223
+ def get_optimizer(cls, config, weight_decay_mask=None):
224
+ config = cls.get_default_config(config)
225
+
226
+ if config.lr_schedule_type == "warmup_cosine_decay_schedule":
227
+ learning_rate_schedule = optax.warmup_cosine_decay_schedule(
228
+ init_value=config.init_lr,
229
+ peak_value=config.lr,
230
+ warmup_steps=config.lr_warmup_steps,
231
+ decay_steps=config.lr_decay_steps,
232
+ end_value=config.end_lr,
233
+ )
234
+ elif config.lr_schedule_type == "warmup_constant":
235
+ learning_rate_schedule = optax.join_schedules(
236
+ [
237
+ optax.linear_schedule(
238
+ init_value=config.init_lr,
239
+ end_value=config.lr,
240
+ transition_steps=config.lr_warmup_steps,
241
+ ),
242
+ optax.constant_schedule(config.lr),
243
+ ],
244
+ [config.lr_warmup_steps],
245
+ )
246
+ elif config.lr_schedule_type == "exponential_decay":
247
+ learning_rate_schedule = optax.exponential_decay(
248
+ init_value=config.lr,
249
+ transition_steps=config.lr_decay_steps,
250
+ decay_rate=config.lr_decay_rate,
251
+ transition_begin=0,
252
+ staircase=False,
253
+ end_value=config.end_lr,
254
+ )
255
+ else:
256
+ raise ValueError('config.lr_schedule_type must be "warmup_cosine_decay_schedule", "warmup_constant", or "exponential_decay"')
257
+
258
+ optimizer_info = dict(
259
+ learning_rate_schedule=learning_rate_schedule,
260
+ )
261
+
262
+ optimizer = optax.chain(
263
+ optax.clip_by_global_norm(config.clip_gradient),
264
+ optax.lion(
265
+ learning_rate=learning_rate_schedule,
266
+ weight_decay=config.weight_decay,
267
+ b1=config.b1,
268
+ b2=config.b2,
269
+ mask=weight_decay_mask,
270
+ mu_dtype=jnp.bfloat16 if config.bf16_momentum else jnp.float32,
271
+ ),
272
+ )
273
+
274
+ return optimizer, optimizer_info
275
+
276
+
277
+ class OptaxScheduledWeightDecayState(NamedTuple):
278
+ count: jax.Array
279
+
280
+
281
+ def optax_add_scheduled_weight_decay(schedule_fn, mask=None):
282
+ """ Apply weight decay with schedule. """
283
+
284
+ def init_fn(params):
285
+ del params
286
+ return OptaxScheduledWeightDecayState(count=jnp.zeros([], jnp.int32))
287
+
288
+ def update_fn(updates, state, params):
289
+ if params is None:
290
+ raise ValueError('Params cannot be None for weight decay!')
291
+
292
+ weight_decay = schedule_fn(state.count)
293
+ updates = jax.tree_util.tree_map(
294
+ lambda g, p: g + weight_decay * p, updates, params
295
+ )
296
+ return updates, OptaxScheduledWeightDecayState(
297
+ count=optax.safe_int32_increment(state.count)
298
+ )
299
+
300
+ if mask is not None:
301
+ return optax.masked(optax.GradientTransformation(init_fn, update_fn), mask)
302
+ return optax.GradientTransformation(init_fn, update_fn)
EasyLM/scripts/__init__.py ADDED
File without changes
EasyLM/scripts/benchmark_attention.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ from time import time
3
+ import os
4
+ import numpy as np
5
+ import jax
6
+ import jax.flatten_util
7
+ import jax.numpy as jnp
8
+ import mlxu
9
+ from EasyLM.bpt import blockwise_attn
10
+ from EasyLM.jax_utils import (
11
+ get_float_dtype_by_name, set_random_seed, next_rng, JaxRNG
12
+ )
13
+
14
+
15
+ FLAGS, _ = mlxu.define_flags_with_default(
16
+ seed=42,
17
+ dtype='fp32',
18
+ embed_dim=2048,
19
+ n_heads=16,
20
+ ref_attn_seq_len=2048,
21
+ eff_attn_seq_len=16384,
22
+ batch_size=1,
23
+ query_chunk_size=2048,
24
+ key_chunk_size=2048,
25
+ warmup_steps=40,
26
+ steps=200,
27
+ )
28
+
29
+
30
+ def main(argv):
31
+
32
+ def random_kqv(rng_key, seq_len):
33
+ rng_generator = JaxRNG(rng_key)
34
+ kqv = []
35
+ for i in range(3):
36
+ kqv.append(
37
+ jax.random.normal(
38
+ rng_generator(),
39
+ (FLAGS.batch_size, seq_len, FLAGS.n_heads, FLAGS.embed_dim // FLAGS.n_heads),
40
+ dtype=get_float_dtype_by_name(FLAGS.dtype)
41
+ )
42
+ )
43
+ return tuple(kqv)
44
+
45
+ def reference_attn(query, key, value):
46
+ dtype = get_float_dtype_by_name(FLAGS.dtype)
47
+ query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
48
+ logits = jnp.einsum("bqhc,bkhc->bhqk", query, key)
49
+ mask_value = jnp.finfo(logits.dtype).min
50
+ _, q_seq_len, _, _ = query.shape
51
+ _, kv_seq_len, _, _ = key.shape
52
+ mask_shape = (q_seq_len, kv_seq_len)
53
+ row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
54
+ col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
55
+ causal_mask = (row_ids < col_ids)[None, None, :, :]
56
+ logits = logits + jnp.where(causal_mask, mask_value, 0.0)
57
+ weights = jax.nn.softmax(logits, axis=-1)
58
+ out = jnp.einsum("bhqk,bkhc->bqhc", weights, value)
59
+ return out
60
+
61
+ def efficient_attention(query, key, value):
62
+ dtype = get_float_dtype_by_name(FLAGS.dtype)
63
+ return blockwise_attn(
64
+ query, key, value,
65
+ bias=None,
66
+ deterministic=True,
67
+ dropout_rng=None,
68
+ attn_pdrop=0.0,
69
+ causal=True,
70
+ query_chunk_size=FLAGS.query_chunk_size,
71
+ key_chunk_size=FLAGS.key_chunk_size,
72
+ dtype=get_float_dtype_by_name(FLAGS.dtype),
73
+ policy=jax.checkpoint_policies.nothing_saveable(),
74
+ precision=None,
75
+ float32_logits=True,
76
+ prevent_cse=True,
77
+ )
78
+
79
+
80
+ @partial(jax.jit, static_argnums=(1,))
81
+ def reference_attn_forward_backward(rng_key, seq_len):
82
+ @partial(jax.grad, argnums=(0, 1, 2))
83
+ @partial(jax.checkpoint, policy=jax.checkpoint_policies.nothing_saveable())
84
+ def grad_fn(query, key, value):
85
+ out = reference_attn(query, key, value)
86
+ return jnp.mean(out)
87
+
88
+ query, key, value = random_kqv(rng_key, seq_len)
89
+ return jax.flatten_util.ravel_pytree(
90
+ grad_fn(query, key, value)[1]
91
+ )[0].mean()
92
+
93
+ @partial(jax.jit, static_argnums=(1,))
94
+ def efficient_attn_forward_backward(rng_key, seq_len):
95
+ @partial(jax.grad, argnums=(0, 1, 2))
96
+ def grad_fn(query, key, value):
97
+ out = efficient_attention(query, key, value)
98
+ return jnp.mean(out)
99
+
100
+ query, key, value = random_kqv(rng_key, seq_len)
101
+ return jax.flatten_util.ravel_pytree(
102
+ grad_fn(query, key, value)[1]
103
+ )[0].mean()
104
+
105
+
106
+ set_random_seed(FLAGS.seed)
107
+
108
+ jax.block_until_ready(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len))
109
+ jax.block_until_ready(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len))
110
+
111
+ all_results = []
112
+ for i in range(FLAGS.warmup_steps):
113
+ all_results.append(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len))
114
+ jax.block_until_ready(all_results)
115
+
116
+ start_time = time()
117
+ all_results = []
118
+ for i in range(FLAGS.steps):
119
+ all_results.append(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len))
120
+
121
+ jax.block_until_ready(all_results)
122
+ elapsed_time_ref_attn = time() - start_time
123
+ print(f'Reference attention: {elapsed_time_ref_attn:.3f} seconds')
124
+
125
+
126
+ all_results = []
127
+ for i in range(FLAGS.warmup_steps):
128
+ all_results.append(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len))
129
+ jax.block_until_ready(all_results)
130
+
131
+
132
+ start_time = time()
133
+ all_results = []
134
+ for i in range(FLAGS.steps):
135
+ all_results.append(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len))
136
+
137
+ jax.block_until_ready(all_results)
138
+ elapsed_time_efficient_attn = time() - start_time
139
+ print(f'Efficient attention: {elapsed_time_efficient_attn:.3f} seconds')
140
+
141
+ flops_ratio = (FLAGS.eff_attn_seq_len / FLAGS.ref_attn_seq_len) ** 2
142
+ efficiency = elapsed_time_ref_attn / elapsed_time_efficient_attn * flops_ratio
143
+ print(f'Efficiency: {efficiency:.3f}')
144
+
145
+
146
+ if __name__ == '__main__':
147
+ mlxu.run(main)
148
+
149
+
150
+
EasyLM/scripts/convert_checkpoint.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This script converts model checkpoint trained by EsayLM to a standard
2
+ # mspack checkpoint that can be loaded by huggingface transformers or
3
+ # flax.serialization.msgpack_restore. Such conversion allows models to be
4
+ # used by other frameworks that integrate with huggingface transformers.
5
+
6
+ import pprint
7
+ from functools import partial
8
+ import os
9
+ import numpy as np
10
+ import mlxu
11
+ import jax.numpy as jnp
12
+ import flax.serialization
13
+ from EasyLM.checkpoint import StreamingCheckpointer
14
+ from EasyLM.jax_utils import float_to_dtype
15
+
16
+
17
+ FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
18
+ load_checkpoint='',
19
+ output_file='',
20
+ streaming=False,
21
+ float_dtype='bf16',
22
+ )
23
+
24
+
25
+ def main(argv):
26
+ assert FLAGS.load_checkpoint != '' and FLAGS.output_file != '', 'input and output must be specified'
27
+ params = StreamingCheckpointer.load_trainstate_checkpoint(
28
+ FLAGS.load_checkpoint, disallow_trainstate=True
29
+ )[1]['params']
30
+
31
+ if FLAGS.streaming:
32
+ StreamingCheckpointer.save_train_state_to_file(
33
+ params, FLAGS.output_file, float_dtype=FLAGS.float_dtype
34
+ )
35
+ else:
36
+ params = float_to_dtype(params, FLAGS.float_dtype)
37
+ with mlxu.open_file(FLAGS.output, 'wb') as fout:
38
+ fout.write(flax.serialization.msgpack_serialize(params, in_place=True))
39
+
40
+
41
+ if __name__ == "__main__":
42
+ mlxu.run(main)
EasyLM/scripts/diff_checkpoint.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This script converts model checkpoint trained by EsayLM to a standard
2
+ # mspack checkpoint that can be loaded by huggingface transformers or
3
+ # flax.serialization.msgpack_restore. Such conversion allows models to be
4
+ # used by other frameworks that integrate with huggingface transformers.
5
+
6
+ import pprint
7
+ from functools import partial
8
+ import os
9
+ import numpy as np
10
+ import jax
11
+ import jax.numpy as jnp
12
+ import flax.serialization
13
+ import mlxu
14
+ from EasyLM.checkpoint import StreamingCheckpointer
15
+ from EasyLM.jax_utils import float_to_dtype
16
+
17
+
18
+ FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
19
+ recover_diff=False,
20
+ load_base_checkpoint='',
21
+ load_target_checkpoint='',
22
+ output_file='',
23
+ streaming=True,
24
+ float_dtype='bf16',
25
+ )
26
+
27
+
28
+ def main(argv):
29
+ assert FLAGS.load_base_checkpoint != '' and FLAGS.load_target_checkpoint != ''
30
+ assert FLAGS.output_file != ''
31
+ base_params = StreamingCheckpointer.load_trainstate_checkpoint(
32
+ FLAGS.load_base_checkpoint, disallow_trainstate=True
33
+ )[1]['params']
34
+
35
+ target_params = StreamingCheckpointer.load_trainstate_checkpoint(
36
+ FLAGS.load_target_checkpoint, disallow_trainstate=True
37
+ )[1]['params']
38
+
39
+ if FLAGS.recover_diff:
40
+ params = jax.tree_util.tree_map(
41
+ lambda b, t: b + t, base_params, target_params
42
+ )
43
+ else:
44
+ params = jax.tree_util.tree_map(
45
+ lambda b, t: t - b, base_params, target_params
46
+ )
47
+
48
+ if FLAGS.streaming:
49
+ StreamingCheckpointer.save_train_state_to_file(
50
+ params, FLAGS.output_file, float_dtype=FLAGS.float_dtype
51
+ )
52
+ else:
53
+ params = float_to_dtype(params, FLAGS.float_dtype)
54
+ with mlxu.open_file(FLAGS.output, 'wb') as fout:
55
+ fout.write(flax.serialization.msgpack_serialize(params, in_place=True))
56
+
57
+
58
+ if __name__ == "__main__":
59
+ mlxu.run(main)
EasyLM/scripts/lm_eval_harness.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This script runs lm_eval_harness evaluations against a served language model.
2
+ # Typically, you need to run a language model server first, e.g.:
3
+ # python -m EasyLM.models.gptj.gptj_serve ...
4
+
5
+ import dataclasses
6
+ import pprint
7
+ from functools import partial
8
+ import os
9
+ from tqdm import tqdm, trange
10
+ import numpy as np
11
+ import mlxu
12
+
13
+ from flax.traverse_util import flatten_dict
14
+ from lm_eval import evaluator, tasks
15
+ from lm_eval.base import LM
16
+
17
+ from EasyLM.serving import LMClient
18
+
19
+
20
+ FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
21
+ tasks='wsc,piqa,winogrande,openbookqa,logiqa',
22
+ shots=0,
23
+ limit=0,
24
+ write_out=False,
25
+ lm_client=LMClient.get_default_config(),
26
+ logger=mlxu.WandBLogger.get_default_config(),
27
+ )
28
+
29
+
30
+ class LMEvalHarnessInterface(LM):
31
+
32
+ def __init__(self, lm_client):
33
+ self.lm_client = lm_client
34
+
35
+ def greedy_until(self, inputs):
36
+ prefix, until = zip(*inputs)
37
+ return self.lm_client.greedy_until(prefix, until)
38
+
39
+ def loglikelihood_rolling(self, inputs):
40
+ loglikelihood, is_greedy = self.lm_client.loglikelihood_rolling(inputs)
41
+ return list(zip(loglikelihood, is_greedy))
42
+
43
+ def loglikelihood(self, inputs):
44
+ prefix, text = zip(*inputs)
45
+ loglikelihood, is_greedy = self.lm_client.loglikelihood(prefix, text)
46
+ return list(zip(loglikelihood, is_greedy))
47
+
48
+
49
+ def main(argv):
50
+ logger = mlxu.WandBLogger(
51
+ config=FLAGS.logger, variant=mlxu.get_user_flags(FLAGS, FLAGS_DEF)
52
+ )
53
+ model = LMEvalHarnessInterface(LMClient(FLAGS.lm_client))
54
+ task_list = FLAGS.tasks.split(',')
55
+ results = evaluator.evaluate(
56
+ model, tasks.get_task_dict(task_list), False, FLAGS.shots,
57
+ limit=None if FLAGS.limit <= 0 else FLAGS.limit,
58
+ write_out=FLAGS.write_out,
59
+ )
60
+ logger.log(flatten_dict(results['results'], sep='/'))
61
+ pprint.pprint(results)
62
+
63
+
64
+ if __name__ == "__main__":
65
+ mlxu.run(main)
EasyLM/scripts/lm_eval_json.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import mlxu
3
+ from EasyLM.serving import LMClient
4
+
5
+
6
+ FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
7
+ input_file='',
8
+ output_file='',
9
+ prefix_field='prefix',
10
+ text_field='text',
11
+ until_field='until',
12
+ eval_type='loglikelihood',
13
+ lm_client=LMClient.get_default_config(),
14
+ )
15
+
16
+
17
+ def main(argv):
18
+ lm_client = LMClient(FLAGS.lm_client)
19
+ with mlxu.open_file(FLAGS.input_file, 'r') as fin:
20
+ input_data = json.load(fin)
21
+
22
+ if FLAGS.eval_type == 'loglikelihood':
23
+ prefix = input_data[FLAGS.prefix_field]
24
+ text = input_data[FLAGS.text_field]
25
+ loglikelihoods, is_greedys = lm_client.loglikelihood(prefix, text)
26
+ output_data = {
27
+ 'loglikelihood': loglikelihoods,
28
+ 'is_greedy': is_greedys,
29
+ }
30
+ elif FLAGS.eval_type == 'loglikelihood_rolling':
31
+ text = input_data[FLAGS.text_field]
32
+ loglikelihoods, is_greedys = lm_client.loglikelihood_rolling(text)
33
+ output_data = {
34
+ 'loglikelihood': loglikelihoods,
35
+ 'is_greedy': is_greedys,
36
+ }
37
+ elif FLAGS.eval_type == 'greedy_until':
38
+ prefix = input_data[FLAGS.prefix_field]
39
+ until = input_data[FLAGS.until_field]
40
+ output_data = {'output_text': lm_client.greedy_until(prefix, until)}
41
+ elif FLAGS.eval_type == 'generate':
42
+ prefix = input_data[FLAGS.prefix_field]
43
+ output_data = {'output_text': lm_client.generate(prefix)}
44
+ else:
45
+ raise ValueError(f'Unknown eval_type: {FLAGS.eval_type}')
46
+
47
+ with mlxu.open_file(FLAGS.output_file, 'w') as fout:
48
+ json.dump(output_data, fout)
49
+
50
+
51
+ if __name__ == "__main__":
52
+ mlxu.run(main)
EasyLM/serving.py ADDED
@@ -0,0 +1,566 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import pprint
3
+ from functools import partial
4
+ import re
5
+ import os
6
+ from threading import Lock
7
+ import urllib
8
+ import time
9
+ from typing import List, Optional, Union
10
+
11
+ from pydantic import BaseModel
12
+ import absl.logging
13
+ from tqdm import tqdm, trange
14
+ import numpy as np
15
+ import mlxu
16
+ from ml_collections import ConfigDict
17
+ import uvicorn
18
+ from fastapi import FastAPI
19
+ import gradio as gr
20
+ import requests
21
+ from requests.exceptions import Timeout, ConnectionError
22
+
23
+
24
+ class InferenceRequest(BaseModel):
25
+ prefix_text: Optional[List[str]] = None
26
+ text: Optional[List[str]] = None
27
+ until: Optional[Union[List[str], List[List[str]]]] = None
28
+ temperature: Optional[float] = None
29
+
30
+
31
+ class ChatRequest(BaseModel):
32
+ prompt: str
33
+ context: str = ''
34
+ temperature: Optional[float] = None
35
+
36
+
37
+ class LMServer(object):
38
+ """ HTTP server for serving langauge models. """
39
+
40
+ @staticmethod
41
+ def get_default_config(updates=None):
42
+ config = ConfigDict()
43
+ config.host = '0.0.0.0'
44
+ config.port = 5007
45
+ config.batch_size = 1
46
+ config.logging = False
47
+ config.pre_compile = 'loglikelihood'
48
+ config.default_temperature = 1.0
49
+ config.greedy_until_max_length = 5000
50
+ config.prepend_to_prefix = ''
51
+ config.append_to_prefix = ''
52
+ config.prepend_to_text = ''
53
+ config.append_to_text = ''
54
+ config.chat_prepend_text = ''
55
+ config.chat_user_prefix = ''
56
+ config.chat_user_suffix = ''
57
+ config.chat_lm_prefix = ''
58
+ config.chat_lm_suffix = ''
59
+ config.notes = ''
60
+
61
+ if updates is not None:
62
+ config.update(ConfigDict(updates).copy_and_resolve_references())
63
+ return config
64
+
65
+ def __init__(self, config):
66
+ self.config = self.get_default_config(config)
67
+ self.lock = Lock()
68
+ self.app = FastAPI()
69
+ self.app.post('/loglikelihood')(self.serve_loglikelihood)
70
+ self.app.post('/loglikelihood-rolling')(self.serve_loglikelihood_rolling)
71
+ self.app.post('/generate')(self.serve_generate)
72
+ self.app.post('/greedy-until')(self.serve_greedy_until)
73
+ self.app.post('/chat')(self.serve_chat)
74
+ self.app.get('/ready')(self.serve_ready)
75
+ self.app = gr.mount_gradio_app(self.app, self.create_chat_app(), '/')
76
+
77
+ @staticmethod
78
+ def loglikelihood(prefix_text, text):
79
+ raise NotImplementedError()
80
+
81
+ @staticmethod
82
+ def loglikelihood_rolling(text):
83
+ raise NotImplementedError()
84
+
85
+ @staticmethod
86
+ def generate(text, temperature):
87
+ raise NotImplementedError()
88
+
89
+ @staticmethod
90
+ def greedy_until(prefix_text, until, max_length):
91
+ raise NotImplementedError()
92
+
93
+ @staticmethod
94
+ def to_list(x):
95
+ if isinstance(x, np.ndarray):
96
+ return x.tolist()
97
+ return x
98
+
99
+ def serve_ready(self):
100
+ return 'Ready!\n'
101
+
102
+ def serve_loglikelihood(self, data: InferenceRequest):
103
+ with self.lock:
104
+ if self.config.logging:
105
+ absl.logging.info(
106
+ '\n========= Serving Log Likelihood Request ========= \n'
107
+ + pprint.pformat(data) + '\n'
108
+ )
109
+
110
+ if data.prefix_text is None:
111
+ data.prefix_text = ['' for _ in data.text]
112
+
113
+ prefix_text = [
114
+ self.config.prepend_to_prefix + p + self.config.append_to_prefix
115
+ for p in data.prefix_text
116
+ ]
117
+ text = [
118
+ self.config.prepend_to_text + t + self.config.append_to_text
119
+ for t in data.text
120
+ ]
121
+
122
+ log_likelihood = []
123
+ is_greedy = []
124
+ for i in trange(0, len(text), self.config.batch_size, ncols=0):
125
+ batch_prefix_text = prefix_text[i:i + self.config.batch_size]
126
+ batch_text = text[i:i + self.config.batch_size]
127
+ batch_size = len(batch_text)
128
+
129
+ if batch_size < self.config.batch_size:
130
+ extra = self.config.batch_size - batch_size
131
+ batch_prefix_text.extend(['a' for _ in range(extra)])
132
+ batch_text.extend(['a' for _ in range(extra)])
133
+
134
+ batch_log_likelihood, batch_is_greedy = self.loglikelihood(
135
+ batch_prefix_text, batch_text
136
+ )
137
+ batch_log_likelihood = self.to_list(batch_log_likelihood)
138
+ batch_is_greedy = self.to_list(batch_is_greedy)
139
+ log_likelihood.extend(batch_log_likelihood[:batch_size])
140
+ is_greedy.extend(batch_is_greedy[:batch_size])
141
+
142
+ output = {
143
+ 'prefix_text': data.prefix_text,
144
+ 'text': data.text,
145
+ 'log_likelihood': log_likelihood,
146
+ 'is_greedy': is_greedy,
147
+ }
148
+ if self.config.logging:
149
+ absl.logging.info(
150
+ '\n========= Output ========= \n'
151
+ + pprint.pformat(output) + '\n'
152
+ )
153
+
154
+ return output
155
+
156
+ def serve_loglikelihood_rolling(self, data: InferenceRequest):
157
+ with self.lock:
158
+ if self.config.logging:
159
+ absl.logging.info(
160
+ '\n========= Serving Log Likelihood Request ========= \n'
161
+ + pprint.pformat(data) + '\n'
162
+ )
163
+
164
+ text = [
165
+ self.config.prepend_to_text + t + self.config.append_to_text
166
+ for t in data.text
167
+ ]
168
+ log_likelihood = []
169
+ is_greedy = []
170
+ for i in trange(0, len(text), self.config.batch_size, ncols=0):
171
+ batch_text = text[i:i + self.config.batch_size]
172
+ batch_size = len(batch_text)
173
+
174
+ if batch_size < self.config.batch_size:
175
+ extra = self.config.batch_size - batch_size
176
+ batch_text.extend(['a' for _ in range(extra)])
177
+
178
+ batch_log_likelihood, batch_is_greedy = self.loglikelihood_rolling(
179
+ batch_text
180
+ )
181
+ batch_log_likelihood = self.to_list(batch_log_likelihood)
182
+ batch_is_greedy = self.to_list(batch_is_greedy)
183
+ log_likelihood.extend(batch_log_likelihood[:batch_size])
184
+ is_greedy.extend(batch_is_greedy[:batch_size])
185
+
186
+ output = {
187
+ 'text': data.text,
188
+ 'log_likelihood': log_likelihood,
189
+ 'is_greedy': is_greedy,
190
+ }
191
+ if self.config.logging:
192
+ absl.logging.info(
193
+ '\n========= Output ========= \n'
194
+ + pprint.pformat(output) + '\n'
195
+ )
196
+
197
+ return output
198
+
199
+ def serve_generate(self, data: InferenceRequest):
200
+ with self.lock:
201
+ if self.config.logging:
202
+ absl.logging.info(
203
+ '\n========= Serving Generate Request ========= \n'
204
+ + pprint.pformat(data) + '\n'
205
+ )
206
+ prefix_text = [
207
+ self.config.prepend_to_prefix + p + self.config.append_to_prefix
208
+ for p in data.prefix_text
209
+ ]
210
+
211
+ if data.temperature is None:
212
+ data.temperature = self.config.default_temperature
213
+
214
+ output_text = []
215
+ for i in trange(0, len(prefix_text), self.config.batch_size, ncols=0):
216
+ batch_prefix_text = prefix_text[i:i + self.config.batch_size]
217
+ batch_size = len(batch_prefix_text)
218
+
219
+ if batch_size < self.config.batch_size:
220
+ extra = self.config.batch_size - batch_size
221
+ batch_prefix_text.extend(['a' for _ in range(extra)])
222
+
223
+ batch_output_text = self.generate(
224
+ batch_prefix_text,
225
+ temperature=data.temperature,
226
+ )
227
+ output_text.extend(self.to_list(batch_output_text)[:batch_size])
228
+
229
+ output = {
230
+ 'prefix_text': data.prefix_text,
231
+ 'output_text': output_text,
232
+ 'temperature': data.temperature,
233
+ }
234
+ if self.config.logging:
235
+ absl.logging.info(
236
+ '\n========= Output ========= \n'
237
+ + pprint.pformat(output) + '\n'
238
+ )
239
+ return output
240
+
241
+ def serve_greedy_until(self, data: InferenceRequest):
242
+ with self.lock:
243
+ if self.config.logging:
244
+ absl.logging.info(
245
+ '\n========= Serving Greedy Until Request ========= \n'
246
+ + pprint.pformat(data) + '\n'
247
+ )
248
+ prefix_text = [
249
+ self.config.prepend_to_prefix + p + self.config.append_to_prefix
250
+ for p in data.prefix_text
251
+ ]
252
+ until = data.until
253
+ max_length = self.config.greedy_until_max_length
254
+
255
+ output_text = []
256
+ for i in range(0, len(prefix_text), self.config.batch_size):
257
+ batch_prefix_text = prefix_text[i:i + self.config.batch_size]
258
+ batch_until = until[i:i + self.config.batch_size]
259
+ batch_size = len(batch_prefix_text)
260
+
261
+ batch_output_text = self.greedy_until(batch_prefix_text, batch_until, max_length)
262
+ output_text.extend(self.to_list(batch_output_text)[:batch_size])
263
+
264
+ output = {
265
+ 'prefix_text': data.prefix_text,
266
+ 'until': data.until,
267
+ 'max_length': max_length,
268
+ 'output_text': output_text,
269
+ }
270
+ if self.config.logging:
271
+ absl.logging.info(
272
+ '\n========= Output ========= \n'
273
+ + pprint.pformat(output) + '\n'
274
+ )
275
+ return output
276
+
277
+ def process_chat(self, prompt, context, temperature):
278
+ context = (
279
+ context + self.config.chat_user_prefix
280
+ + prompt + self.config.chat_user_suffix
281
+ + self.config.chat_lm_prefix
282
+ )
283
+ response = self.generate(
284
+ [self.config.chat_prepend_text + context],
285
+ temperature=float(temperature),
286
+ )[0]
287
+ context = context + response + self.config.chat_lm_suffix
288
+ return response, context
289
+
290
+ def serve_chat(self, data: ChatRequest):
291
+ if data.temperature is None:
292
+ data.temperature = self.config.default_temperature
293
+ response, context = self.process_chat(
294
+ data.prompt, data.context,
295
+ temperature=data.temperature,
296
+ )
297
+ return {
298
+ 'response': response,
299
+ 'context': context,
300
+ 'temperature': data.temperature,
301
+ }
302
+
303
+ def create_chat_app(self):
304
+ with gr.Blocks(analytics_enabled=False, title='EasyLM Chat') as gradio_chatbot:
305
+ gr.Markdown('# Chatbot Powered by [EasyLM](https://github.com/young-geng/EasyLM)')
306
+ gr.Markdown(self.config.notes)
307
+ chatbot = gr.Chatbot(label='Chat history')
308
+ msg = gr.Textbox(
309
+ placeholder='Type your message here...',
310
+ show_label=False
311
+ )
312
+ with gr.Row():
313
+ send = gr.Button('Send')
314
+ regenerate = gr.Button('Regenerate', interactive=False)
315
+ clear = gr.Button('Reset')
316
+
317
+ temp_slider = gr.Slider(
318
+ label='Temperature', minimum=0, maximum=2.0,
319
+ value=self.config.default_temperature
320
+ )
321
+
322
+ context_state = gr.State(['', ''])
323
+
324
+ def user_fn(user_message, history, context):
325
+ return {
326
+ msg: gr.update(value='', interactive=False),
327
+ clear: gr.update(interactive=False),
328
+ send: gr.update(interactive=False),
329
+ regenerate: gr.update(interactive=False),
330
+ chatbot: history + [[user_message, None]],
331
+ context_state: [context[1], context[1]],
332
+ }
333
+
334
+ def model_fn(history, context, temperature):
335
+ history[-1][1], new_context = self.process_chat(
336
+ history[-1][0], context[0], temperature
337
+ )
338
+ return {
339
+ msg: gr.update(value='', interactive=True),
340
+ clear: gr.update(interactive=True),
341
+ send: gr.update(interactive=True),
342
+ chatbot: history,
343
+ context_state: [context[0], new_context],
344
+ regenerate: gr.update(interactive=True),
345
+ }
346
+
347
+ def regenerate_fn():
348
+ return {
349
+ msg: gr.update(value='', interactive=False),
350
+ clear: gr.update(interactive=False),
351
+ send: gr.update(interactive=False),
352
+ regenerate: gr.update(interactive=False),
353
+ }
354
+
355
+ def clear_fn():
356
+ return {
357
+ chatbot: None,
358
+ msg: '',
359
+ context_state: ['', ''],
360
+ regenerate: gr.update(interactive=False),
361
+ }
362
+
363
+ msg.submit(
364
+ user_fn,
365
+ inputs=[msg, chatbot, context_state],
366
+ outputs=[msg, clear, send, chatbot, context_state, regenerate],
367
+ queue=False
368
+ ).then(
369
+ model_fn,
370
+ inputs=[chatbot, context_state, temp_slider],
371
+ outputs=[msg, clear, send, chatbot, context_state, regenerate],
372
+ queue=True
373
+ )
374
+ send.click(
375
+ user_fn,
376
+ inputs=[msg, chatbot, context_state],
377
+ outputs=[msg, clear, send, chatbot, context_state, regenerate],
378
+ queue=False
379
+ ).then(
380
+ model_fn,
381
+ inputs=[chatbot, context_state, temp_slider],
382
+ outputs=[msg, clear, send, chatbot, context_state, regenerate],
383
+ queue=True
384
+ )
385
+ regenerate.click(
386
+ regenerate_fn,
387
+ inputs=None,
388
+ outputs=[msg, clear, send, regenerate],
389
+ queue=False
390
+ ).then(
391
+ model_fn,
392
+ inputs=[chatbot, context_state, temp_slider],
393
+ outputs=[msg, clear, send, chatbot, context_state, regenerate],
394
+ queue=True
395
+ )
396
+ clear.click(
397
+ clear_fn,
398
+ inputs=None,
399
+ outputs=[chatbot, msg, context_state, regenerate],
400
+ queue=False
401
+ )
402
+
403
+ gradio_chatbot.queue(concurrency_count=1)
404
+ return gradio_chatbot
405
+
406
+ def run(self):
407
+ if self.config.pre_compile != '':
408
+ if self.config.pre_compile == 'all':
409
+ pre_compile = ['loglikelihood', 'generate', 'greedy_until', 'chat']
410
+ else:
411
+ pre_compile = self.config.pre_compile.split(',')
412
+
413
+ pre_compile_data = ['a' for _ in range(self.config.batch_size)]
414
+ for task in pre_compile:
415
+ if task == 'loglikelihood':
416
+ self.loglikelihood(pre_compile_data, pre_compile_data)
417
+ self.loglikelihood_rolling(pre_compile_data)
418
+ elif task == 'generate':
419
+ self.generate(pre_compile_data, 1.0)
420
+ elif task == 'greedy_until':
421
+ self.greedy_until(
422
+ pre_compile_data, pre_compile_data,
423
+ self.config.greedy_until_max_length
424
+ )
425
+ elif task == 'chat':
426
+ self.process_chat('a', 'a', 1.0)
427
+ else:
428
+ raise ValueError(f'Invalid precompile task: {task}!')
429
+
430
+ uvicorn.run(self.app, host=self.config.host, port=self.config.port)
431
+
432
+
433
+ class LMClient(object):
434
+ """ A simple client for the LM server. """
435
+
436
+ @staticmethod
437
+ def get_default_config(updates=None):
438
+ config = ConfigDict()
439
+ config.url = 'http://localhost:5007'
440
+ config.batch_size = 1
441
+ config.wait_for_ready = True
442
+ config.dummy = False
443
+
444
+ if updates is not None:
445
+ config.update(ConfigDict(updates).copy_and_resolve_references())
446
+ return config
447
+
448
+ def __init__(self, config=None):
449
+ self.config = self.get_default_config(config)
450
+ if self.config.wait_for_ready:
451
+ self.wait_for_ready()
452
+
453
+ def wait_for_ready(self):
454
+ if self.config.dummy:
455
+ return
456
+ while True:
457
+ try:
458
+ requests.get(urllib.parse.urljoin(self.config.url, 'ready'))
459
+ return
460
+ except (Timeout, ConnectionError) as e:
461
+ time.sleep(10)
462
+
463
+ @staticmethod
464
+ def batched(iterator, batch_size):
465
+ batch = []
466
+ for example in iterator:
467
+ batch.append(example)
468
+ if len(batch) == batch_size:
469
+ yield batch
470
+ batch = []
471
+ if len(batch) > 0:
472
+ yield batch
473
+
474
+ def loglikelihood(self, prefix, text):
475
+ prefix, text = list(prefix), list(text)
476
+ if self.config.dummy:
477
+ return [-1.0 for _ in text], [False for _ in text]
478
+
479
+ log_likelihood = []
480
+ is_greedy = []
481
+
482
+ batched_iterator = list(zip(
483
+ self.batched(prefix, self.config.batch_size),
484
+ self.batched(text, self.config.batch_size)
485
+ ))
486
+ for batch_prefix, batch_text in tqdm(batched_iterator, ncols=0):
487
+ response = requests.post(
488
+ urllib.parse.urljoin(self.config.url, 'loglikelihood'),
489
+ json={'prefix_text': batch_prefix, 'text': batch_text}
490
+ ).json()
491
+ log_likelihood.extend(response['log_likelihood'])
492
+ is_greedy.extend(response['is_greedy'])
493
+
494
+ return log_likelihood, is_greedy
495
+
496
+ def loglikelihood_rolling(self, text):
497
+ text = list(text)
498
+ if self.config.dummy:
499
+ return [-1.0 for _ in text], [False for _ in text]
500
+
501
+ log_likelihood = []
502
+ is_greedy = []
503
+ batched_iterator = list(self.batched(text, self.config.batch_size))
504
+ for batch_text in tqdm(batched_iterator, ncols=0):
505
+ response = requests.post(
506
+ urllib.parse.urljoin(self.config.url, 'loglikelihood-rolling'),
507
+ json={'text': batch_text}
508
+ ).json()
509
+ log_likelihood.extend(response['log_likelihood'])
510
+ is_greedy.extend(response['is_greedy'])
511
+ return log_likelihood, is_greedy
512
+
513
+ def greedy_until(self, prefix, until):
514
+ prefix, until = list(prefix), list(until)
515
+ if self.config.dummy:
516
+ results = []
517
+ for u in until:
518
+ if isinstance(u, str):
519
+ results.append('dummy text ' + u)
520
+ else:
521
+ results.append('dummy text ' + u[0])
522
+ return results
523
+
524
+ batched_iterator = list(zip(
525
+ self.batched(prefix, self.config.batch_size),
526
+ self.batched(until, self.config.batch_size),
527
+ ))
528
+ output_text = []
529
+ for batch_prefix, batch_until in tqdm(batched_iterator, ncols=0):
530
+ response = requests.post(
531
+ urllib.parse.urljoin(self.config.url, 'greedy-until'),
532
+ json={'prefix_text': batch_prefix, 'until': batch_until}
533
+ ).json()
534
+ output_text.extend(response['output_text'])
535
+ return output_text
536
+
537
+ def generate(self, prefix, temperature=None):
538
+ prefix = list(prefix)
539
+ if self.config.dummy:
540
+ return ['' for _ in prefix]
541
+
542
+ output_text = []
543
+ batched_iterator = list(self.batched(prefix, self.config.batch_size))
544
+ for batch_prefix in tqdm(batched_iterator, ncols=0):
545
+ response = requests.post(
546
+ urllib.parse.urljoin(self.config.url, 'generate'),
547
+ json={
548
+ 'prefix_text': batch_prefix,
549
+ 'temperature': temperature,
550
+ }
551
+ ).json()
552
+ output_text.extend(response['output_text'])
553
+ return output_text
554
+
555
+ def chat(self, prompt, context, temperature=None):
556
+ if self.config.dummy:
557
+ return ''
558
+ response = requests.post(
559
+ urllib.parse.urljoin(self.config.url, 'chat'),
560
+ json={
561
+ 'prompt': prompt,
562
+ 'context': context,
563
+ 'temperature': temperature,
564
+ }
565
+ ).json()
566
+ return response['response'], response['context']
config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "LlamaForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "bos_token_id": 1,
7
+ "eos_token_id": 2,
8
+ "hidden_act": "silu",
9
+ "hidden_size": 3200,
10
+ "initializer_range": 0.02,
11
+ "intermediate_size": 8640,
12
+ "max_position_embeddings": 2048,
13
+ "model_type": "llama",
14
+ "num_attention_heads": 32,
15
+ "num_hidden_layers": 26,
16
+ "num_key_value_heads": 32,
17
+ "pretraining_tp": 1,
18
+ "rms_norm_eps": 1e-06,
19
+ "rope_scaling": null,
20
+ "rope_theta": 10000.0,
21
+ "tie_word_embeddings": false,
22
+ "torch_dtype": "float16",
23
+ "transformers_version": "4.34.0.dev0",
24
+ "use_cache": true,
25
+ "vocab_size": 64256
26
+ }
convert_to_hf_model.sh ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ JAX_PLATFORM_NAME=cpu python3 -m EasyLM.models.llama.convert_easylm_to_hf \
2
+ --load_checkpoint='' \
3
+ --model_size='3b' \
4
+ --output_dir='./'
pretrain_llama_3b.sh ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /bin/bash
2
+
3
+ # Put your WANDB API key here to enable logging to wandb.
4
+ export WANDB_API_KEY=''
5
+
6
+ # TPU specific flags to improve training throughput
7
+ export LIBTPU_INIT_ARGS='--xla_jf_spmd_threshold_for_windowed_einsum_mib=0 --xla_tpu_spmd_threshold_for_allgather_cse=10000 --xla_tpu_spmd_rewrite_einsum_with_reshape=true --xla_enable_async_all_gather=true --jax_enable_async_collective_offload=true --xla_tpu_enable_latency_hiding_scheduler=true TPU_MEGACORE=MEGACORE_DENSE'
8
+
9
+
10
+ python3 -m EasyLM.models.llama.llama_train \
11
+ --jax_distributed.initialize_jax_distributed=True \
12
+ --mesh_dim='1,-1,1' \
13
+ --dtype='bf16' \
14
+ --total_steps=900000 \
15
+ --eval_freq=50000 \
16
+ --log_freq=1000 \
17
+ --save_model_freq=2000 \
18
+ --save_milestone_freq=50000 \
19
+ --load_llama_config='3b' \
20
+ --update_llama_config='' \
21
+ --load_dataset_state='' \
22
+ --load_checkpoint='' \
23
+ --tokenizer.pretrained_model_name_or_path='./' \
24
+ --optimizer.type='lion' \
25
+ --optimizer.lion_optimizer.weight_decay=1.0 \
26
+ --optimizer.lion_optimizer.lr_schedule_type='warmup_constant' \
27
+ --optimizer.lion_optimizer.lr=3e-4 \
28
+ --optimizer.lion_optimizer.end_lr=3e-5 \
29
+ --optimizer.lion_optimizer.lr_warmup_steps=60000 \
30
+ --optimizer.lion_optimizer.lr_decay_steps=100000 \
31
+ --optimizer.lion_optimizer.bf16_momentum=True \
32
+ --train_dataset.type='huggingface' \
33
+ --train_dataset.text_processor.fields='text' \
34
+ --train_dataset.huggingface_dataset.path='/researchdisk/lm_training_dataset_first_stage' \
35
+ --train_dataset.huggingface_dataset.split='train' \
36
+ --train_dataset.huggingface_dataset.seq_length=2048 \
37
+ --train_dataset.huggingface_dataset.batch_size=64 \
38
+ --eval_dataset.type='huggingface' \
39
+ --eval_dataset.text_processor.fields='text' \
40
+ --eval_dataset.huggingface_dataset.path='/researchdisk/lm_training_dataset_first_stage' \
41
+ --eval_dataset.huggingface_dataset.split='validation' \
42
+ --eval_dataset.huggingface_dataset.seq_length=2048 \
43
+ --eval_dataset.huggingface_dataset.batch_size=64 \
44
+ --checkpointer.save_optimizer_state=True \
45
+ --logger.online=True \
46
+ --logger.prefix='EasyLM' \
47
+ --logger.project="llama-3b-finnish-v2" \
48
+ --logger.output_dir="gs://finnish-nlp-research-us/llama-3b-v2-checkpoint" \
49
+ --logger.wandb_dir="./"
50
+