aapot
commited on
Commit
•
5a63fc6
1
Parent(s):
64db1e7
Add easylm training code
Browse files- .gitignore +1 -0
- EasyLM/__init__.py +0 -0
- EasyLM/bpt.py +228 -0
- EasyLM/checkpoint.py +212 -0
- EasyLM/data.py +431 -0
- EasyLM/jax_utils.py +403 -0
- EasyLM/models/__init__.py +0 -0
- EasyLM/models/gptj/__init__.py +0 -0
- EasyLM/models/gptj/gptj_model.py +1054 -0
- EasyLM/models/gptj/gptj_serve.py +396 -0
- EasyLM/models/gptj/gptj_train.py +272 -0
- EasyLM/models/llama/convert_easylm_to_hf.py +338 -0
- EasyLM/models/llama/convert_hf_to_easylm.py +196 -0
- EasyLM/models/llama/convert_torch_to_easylm.py +68 -0
- EasyLM/models/llama/llama_model.py +1360 -0
- EasyLM/models/llama/llama_serve.py +386 -0
- EasyLM/models/llama/llama_train.py +268 -0
- EasyLM/models/roberta/__init__.py +0 -0
- EasyLM/models/roberta/roberta_model.py +1694 -0
- EasyLM/models/roberta/roberta_train.py +307 -0
- EasyLM/optimizers.py +302 -0
- EasyLM/scripts/__init__.py +0 -0
- EasyLM/scripts/benchmark_attention.py +150 -0
- EasyLM/scripts/convert_checkpoint.py +42 -0
- EasyLM/scripts/diff_checkpoint.py +59 -0
- EasyLM/scripts/lm_eval_harness.py +65 -0
- EasyLM/scripts/lm_eval_json.py +52 -0
- EasyLM/serving.py +566 -0
- config.json +26 -0
- convert_to_hf_model.sh +4 -0
- pretrain_llama_3b.sh +50 -0
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
__pycache__
|
EasyLM/__init__.py
ADDED
File without changes
|
EasyLM/bpt.py
ADDED
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
An implementation of Blockwise parallel transformer https://arxiv.org/abs/2305.19370
|
3 |
+
Also include a reference implementation of memory-efficient transformer https://arxiv.org/abs/2112.05682
|
4 |
+
"""
|
5 |
+
|
6 |
+
import functools
|
7 |
+
from typing import NamedTuple
|
8 |
+
|
9 |
+
import flax.linen as nn
|
10 |
+
import jax
|
11 |
+
import jax.lax as lax
|
12 |
+
import jax.numpy as jnp
|
13 |
+
from einops import rearrange
|
14 |
+
|
15 |
+
"""
|
16 |
+
Computing ffn blockwise without materializing the large hidden tensor, training
|
17 |
+
4x longer sequences than the memory-efficient transformer.
|
18 |
+
Blockwise parallel transformer https://arxiv.org/abs/2305.19370 Liu et al. 2023
|
19 |
+
"""
|
20 |
+
def blockwise_ffn(remat_ffn, inputs, chunk_size=2048, deterministic=True):
|
21 |
+
# remat_ffn: a rematerialized ffn with policy jax.checkpoint_policies.nothing_saveable()
|
22 |
+
# inputs: (batch, seq_len, dim)
|
23 |
+
# chunk_size: the chunk size to split the sequence
|
24 |
+
inputs = rearrange(inputs, 'b (c n) d -> b c n d', c=chunk_size)
|
25 |
+
def scan_ffn(remat_ffn, carry, hidden_states):
|
26 |
+
outputs = remat_ffn(hidden_states, deterministic=deterministic)
|
27 |
+
return carry, outputs
|
28 |
+
scan_axis = inputs.ndim - 2
|
29 |
+
_, res = nn.scan(
|
30 |
+
scan_ffn,
|
31 |
+
variable_broadcast="params",
|
32 |
+
split_rngs={"params": False, "dropout": True},
|
33 |
+
in_axes=scan_axis,
|
34 |
+
out_axes=scan_axis,
|
35 |
+
)(remat_ffn, None, inputs)
|
36 |
+
res = rearrange(res, 'b c n d -> b (c n) d')
|
37 |
+
return res
|
38 |
+
|
39 |
+
|
40 |
+
"""
|
41 |
+
Compute attention blockwise without materializing the full attention matrix,
|
42 |
+
initially proposed in memory-efficient transformer https://arxiv.org/abs/2112.05682 Rabe et al. 2021;
|
43 |
+
flash attention https://arxiv.org/abs/2205.14135 Dao et al. 2022 proposes a CUDA
|
44 |
+
efficient implementation; blockwise parallel transformer https://arxiv.org/abs/2305.19370
|
45 |
+
Liu et al. 2023 proposes blockwise computing both attention and FFN, enabling 4x
|
46 |
+
longer sequences than memory-efficient/flash-attention and fusion of attention and FFN.
|
47 |
+
"""
|
48 |
+
def blockwise_attn(
|
49 |
+
query, key, value,
|
50 |
+
bias=None,
|
51 |
+
deterministic=True,
|
52 |
+
dropout_rng=None,
|
53 |
+
attn_pdrop=0.0,
|
54 |
+
causal=True,
|
55 |
+
query_chunk_size=2048,
|
56 |
+
key_chunk_size=2048,
|
57 |
+
dtype=jnp.float32,
|
58 |
+
policy=jax.checkpoint_policies.nothing_saveable(),
|
59 |
+
precision=None,
|
60 |
+
float32_logits=True,
|
61 |
+
prevent_cse=True,
|
62 |
+
):
|
63 |
+
# query, key, value: (batch, seq_len, num_heads, dim_per_head)
|
64 |
+
# bias: (batch, seq_len) can be used to mask out attention (e.g. padding)
|
65 |
+
# causal: whether to use causal mask
|
66 |
+
# policy: one of jax.checkpoint_policies
|
67 |
+
query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
|
68 |
+
if float32_logits:
|
69 |
+
query = query.astype(jnp.float32)
|
70 |
+
key = key.astype(jnp.float32)
|
71 |
+
|
72 |
+
batch, q_len, num_heads, dim_per_head = query.shape
|
73 |
+
batch, kv_len, num_heads, dim_per_head = key.shape
|
74 |
+
batch, kv_len, num_heads, dim_per_head = value.shape
|
75 |
+
|
76 |
+
num_q = q_len // query_chunk_size
|
77 |
+
num_kv = kv_len // key_chunk_size
|
78 |
+
query = query.reshape((batch, num_q, query_chunk_size, num_heads, dim_per_head))
|
79 |
+
key = key.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
|
80 |
+
value = value.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
|
81 |
+
|
82 |
+
query = jnp.moveaxis(query, 1, 0)
|
83 |
+
key = jnp.moveaxis(key, 1, 0)
|
84 |
+
value = jnp.moveaxis(value, 1, 0)
|
85 |
+
|
86 |
+
if bias is not None:
|
87 |
+
for bias_dim, broadcast_dim in zip(bias.shape, (batch, num_heads, q_len, kv_len)):
|
88 |
+
assert bias_dim == 1 or bias_dim == broadcast_dim
|
89 |
+
if not deterministic and attn_pdrop > 0.0:
|
90 |
+
attn_dropout_rng, dropout_rng = jax.random.split(dropout_rng)
|
91 |
+
attn_dropout = jax.random.bernoulli(attn_dropout_rng, attn_pdrop, (batch, num_heads, q_len, kv_len))
|
92 |
+
else:
|
93 |
+
attn_dropout = None
|
94 |
+
|
95 |
+
_chunk_bias_fn = functools.partial(
|
96 |
+
_chunk_attention_bias,
|
97 |
+
query_chunk_size, key_chunk_size, bias, deterministic,
|
98 |
+
attn_dropout, attn_pdrop, causal, dtype)
|
99 |
+
|
100 |
+
def scan_attention(args):
|
101 |
+
query_chunk, query_chunk_idx = args
|
102 |
+
|
103 |
+
@functools.partial(jax.checkpoint, prevent_cse=prevent_cse, policy=policy)
|
104 |
+
def scan_kv_block(carry, args):
|
105 |
+
key_chunk, value_chunk, key_chunk_idx = args
|
106 |
+
(numerator, denominator, prev_max_score) = carry
|
107 |
+
attn_weights = jnp.einsum('bqhd,bkhd->bqhk', query_chunk, key_chunk, precision=precision)
|
108 |
+
bias_chunk = _chunk_bias_fn(query_chunk_idx, key_chunk_idx)
|
109 |
+
bias_chunk = jnp.moveaxis(bias_chunk, 1, 2)
|
110 |
+
attn_weights = attn_weights + bias_chunk
|
111 |
+
|
112 |
+
max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
|
113 |
+
max_score = jnp.maximum(prev_max_score, max_score)
|
114 |
+
max_score = jax.lax.stop_gradient(max_score)
|
115 |
+
exp_weights = jnp.exp(attn_weights - max_score)
|
116 |
+
exp_values = jnp.einsum(
|
117 |
+
'bqhv,bvhd->bqhd', exp_weights, value_chunk, precision=precision
|
118 |
+
)
|
119 |
+
correction = jnp.exp(prev_max_score - max_score)
|
120 |
+
numerator = numerator * correction + exp_values
|
121 |
+
denominator = denominator * correction + exp_weights.sum(axis=-1, keepdims=True)
|
122 |
+
return Carry(numerator, denominator, max_score), None
|
123 |
+
|
124 |
+
def skip_upper_half(carry, args):
|
125 |
+
key_chunk, value_chunk, key_chunk_idx = args
|
126 |
+
skip_block = jnp.array(False)
|
127 |
+
if causal:
|
128 |
+
skip_block = query_chunk_idx < key_chunk_idx
|
129 |
+
return jax.lax.cond(
|
130 |
+
skip_block,
|
131 |
+
lambda carry, args: (carry, None),
|
132 |
+
scan_kv_block,
|
133 |
+
carry,
|
134 |
+
args,
|
135 |
+
)
|
136 |
+
|
137 |
+
init_carry = Carry(
|
138 |
+
jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype),
|
139 |
+
jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype),
|
140 |
+
(-jnp.inf) * jnp.ones((batch, query_chunk_size, num_heads, 1), dtype=query.dtype),
|
141 |
+
)
|
142 |
+
(numerator, denominator, max_score), _ = lax.scan(
|
143 |
+
skip_upper_half, init_carry, xs=(key, value, jnp.arange(0, num_kv))
|
144 |
+
)
|
145 |
+
outputs = (numerator / denominator).astype(dtype)
|
146 |
+
return outputs
|
147 |
+
|
148 |
+
_, res = lax.scan(
|
149 |
+
lambda _, x: ((), scan_attention(x)),
|
150 |
+
(), xs=(query, jnp.arange(0, num_q))
|
151 |
+
)
|
152 |
+
res = rearrange(res, 'n b c h d -> b (n c) h d')
|
153 |
+
return res
|
154 |
+
|
155 |
+
|
156 |
+
class Carry(NamedTuple):
|
157 |
+
numerator: jax.Array
|
158 |
+
denominator: jax.Array
|
159 |
+
max_so_far: jax.Array
|
160 |
+
|
161 |
+
|
162 |
+
def _chunk_attention_bias(query_chunk_size, key_chunk_size,
|
163 |
+
bias, deterministic, attn_dropout, attn_pdrop, causal,
|
164 |
+
dtype, query_chunk_idx, key_chunk_idx):
|
165 |
+
query_offset = query_chunk_idx * query_chunk_size
|
166 |
+
key_offset = key_chunk_idx * key_chunk_size
|
167 |
+
chunk_bias = jnp.zeros((1, 1, 1, 1), dtype=dtype)
|
168 |
+
if bias is not None:
|
169 |
+
chunk_bias = lax.dynamic_slice(
|
170 |
+
bias,
|
171 |
+
start_indices=(0, 0, query_offset, key_offset),
|
172 |
+
slice_sizes=(*bias.shape[:2], min(bias.shape[-2], query_chunk_size), min(bias.shape[-1], key_chunk_size)),
|
173 |
+
)
|
174 |
+
|
175 |
+
if causal:
|
176 |
+
query_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(query_chunk_size, 1), dimension=0)
|
177 |
+
key_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(1, key_chunk_size), dimension=1)
|
178 |
+
offset = query_offset - key_offset
|
179 |
+
query_idx += offset
|
180 |
+
causal_mask_value = (query_idx < key_idx) * jnp.finfo(dtype).min
|
181 |
+
chunk_bias += causal_mask_value.reshape(1, 1, *causal_mask_value.shape)
|
182 |
+
|
183 |
+
if not deterministic and attn_pdrop > 0.0:
|
184 |
+
attn_dropout_slice = lax.dynamic_slice(
|
185 |
+
attn_dropout,
|
186 |
+
start_indices=(0, 0, query_offset, key_offset),
|
187 |
+
slice_sizes=(
|
188 |
+
*attn_dropout.shape[:2],
|
189 |
+
min(attn_dropout.shape[-2], query_chunk_size),
|
190 |
+
min(attn_dropout.shape[-1], key_chunk_size),
|
191 |
+
),
|
192 |
+
)
|
193 |
+
chunk_bias += attn_dropout_slice * jnp.finfo(dtype).min
|
194 |
+
return chunk_bias.astype(dtype)
|
195 |
+
|
196 |
+
|
197 |
+
if __name__ == '__main__':
|
198 |
+
# test
|
199 |
+
def reference_attn(query, key, value, causal, dtype):
|
200 |
+
query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
|
201 |
+
logits = jnp.einsum("bqhc,bkhc->bhqk", query, key)
|
202 |
+
if causal:
|
203 |
+
mask_value = jnp.finfo(logits.dtype).min
|
204 |
+
_, q_seq_len, _, _ = query.shape
|
205 |
+
_, kv_seq_len, _, _ = key.shape
|
206 |
+
mask_shape = (q_seq_len, kv_seq_len)
|
207 |
+
row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
|
208 |
+
col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
|
209 |
+
causal_mask = (row_ids < col_ids)[None, None, :, :]
|
210 |
+
logits = logits + jnp.where(causal_mask, mask_value, 0.0)
|
211 |
+
weights = jax.nn.softmax(logits, axis=-1)
|
212 |
+
out = jnp.einsum("bhqk,bkhc->bqhc", weights, value)
|
213 |
+
return out
|
214 |
+
|
215 |
+
# random inputs
|
216 |
+
shape = (1, 32, 8, 64)
|
217 |
+
query = jax.random.normal(jax.random.PRNGKey(0), shape)
|
218 |
+
key = jax.random.normal(jax.random.PRNGKey(1), shape)
|
219 |
+
value = jax.random.normal(jax.random.PRNGKey(2), shape)
|
220 |
+
|
221 |
+
causal = True
|
222 |
+
chunk_size = 4
|
223 |
+
policy = jax.checkpoint_policies.nothing_saveable()
|
224 |
+
|
225 |
+
blockwise = blockwise_attn(query, key, value, None, False, None, 0.0, causal, chunk_size, chunk_size, jnp.float32, policy, 'float32', True, False)
|
226 |
+
reference = reference_attn(query, key, value, causal, 'float32')
|
227 |
+
|
228 |
+
assert jnp.allclose(reference, blockwise, atol=1e-6)
|
EasyLM/checkpoint.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
from ml_collections import ConfigDict
|
4 |
+
import mlxu
|
5 |
+
import jax
|
6 |
+
import jax.numpy as jnp
|
7 |
+
import flax
|
8 |
+
from flax.serialization import (
|
9 |
+
from_bytes, to_bytes, to_state_dict, from_state_dict
|
10 |
+
)
|
11 |
+
from flax.traverse_util import flatten_dict, unflatten_dict, empty_node
|
12 |
+
import msgpack
|
13 |
+
|
14 |
+
from EasyLM.jax_utils import tree_apply, float_tensor_to_dtype
|
15 |
+
|
16 |
+
|
17 |
+
class StreamingCheckpointer(object):
|
18 |
+
""" Custom msgpack checkpointer that saves large train states by serializing
|
19 |
+
and saving tensors one by one in a streaming fashion. Avoids running
|
20 |
+
out of memory or local TPU disk with default flax checkpointer.
|
21 |
+
"""
|
22 |
+
|
23 |
+
@staticmethod
|
24 |
+
def get_default_config(updates=None):
|
25 |
+
config = ConfigDict()
|
26 |
+
config.float_dtype = 'bf16'
|
27 |
+
config.save_optimizer_state = False
|
28 |
+
|
29 |
+
if updates is not None:
|
30 |
+
config.update(ConfigDict(updates).copy_and_resolve_references())
|
31 |
+
return config
|
32 |
+
|
33 |
+
def __init__(self, config, checkpoint_dir, enable=True):
|
34 |
+
self.config = self.get_default_config(config)
|
35 |
+
self.checkpoint_dir = checkpoint_dir
|
36 |
+
self.enable = enable
|
37 |
+
|
38 |
+
def save_checkpoint(self, train_state, filename, gather_fns=None):
|
39 |
+
if self.enable:
|
40 |
+
path = os.path.join(self.checkpoint_dir, filename)
|
41 |
+
else:
|
42 |
+
path = '/dev/null'
|
43 |
+
self.save_train_state_to_file(
|
44 |
+
train_state, path, gather_fns, self.config.float_dtype
|
45 |
+
)
|
46 |
+
|
47 |
+
@staticmethod
|
48 |
+
def save_train_state_to_file(train_state, path, gather_fns=None, float_dtype=None):
|
49 |
+
train_state = to_state_dict(train_state)
|
50 |
+
packer = msgpack.Packer()
|
51 |
+
flattend_train_state = flatten_dict(train_state)
|
52 |
+
if gather_fns is not None:
|
53 |
+
gather_fns = flatten_dict(to_state_dict(gather_fns))
|
54 |
+
|
55 |
+
with mlxu.open_file(path, "wb") as fout:
|
56 |
+
for key, value in flattend_train_state.items():
|
57 |
+
if gather_fns is not None:
|
58 |
+
value = gather_fns[key](value)
|
59 |
+
value = float_tensor_to_dtype(value, float_dtype)
|
60 |
+
fout.write(packer.pack((key, to_bytes(value))))
|
61 |
+
|
62 |
+
def save_pickle(self, obj, filename):
|
63 |
+
if self.enable:
|
64 |
+
path = os.path.join(self.checkpoint_dir, filename)
|
65 |
+
else:
|
66 |
+
path = '/dev/null'
|
67 |
+
mlxu.save_pickle(obj, path)
|
68 |
+
|
69 |
+
def save_all(self, train_state, gather_fns, metadata=None, dataset=None, milestone=False):
|
70 |
+
step = int(jax.device_get(train_state.step))
|
71 |
+
if self.config.save_optimizer_state:
|
72 |
+
checkpoint_state = train_state
|
73 |
+
checkpoint_name = 'streaming_train_state'
|
74 |
+
checkpoint_gather_fns = gather_fns
|
75 |
+
else:
|
76 |
+
checkpoint_state = train_state.params['params']
|
77 |
+
checkpoint_name = 'streaming_params'
|
78 |
+
checkpoint_gather_fns = gather_fns.params['params']
|
79 |
+
|
80 |
+
if milestone:
|
81 |
+
# Save a milestone checkpoint that will not be overwritten
|
82 |
+
self.save_pickle(metadata, f'metadata_{step}.pkl')
|
83 |
+
self.save_pickle(dataset, f'dataset_{step}.pkl')
|
84 |
+
self.save_checkpoint(
|
85 |
+
checkpoint_state, f'{checkpoint_name}_{step}', checkpoint_gather_fns
|
86 |
+
)
|
87 |
+
else:
|
88 |
+
# Save a normal checkpoint that can be overwritten
|
89 |
+
self.save_pickle(metadata, 'metadata.pkl')
|
90 |
+
self.save_pickle(dataset, 'dataset.pkl')
|
91 |
+
self.save_checkpoint(
|
92 |
+
checkpoint_state, f'{checkpoint_name}', checkpoint_gather_fns
|
93 |
+
)
|
94 |
+
|
95 |
+
@staticmethod
|
96 |
+
def load_checkpoint(path, target=None, shard_fns=None, remove_dict_prefix=None):
|
97 |
+
if shard_fns is not None:
|
98 |
+
shard_fns = flatten_dict(
|
99 |
+
to_state_dict(shard_fns)
|
100 |
+
)
|
101 |
+
if remove_dict_prefix is not None:
|
102 |
+
remove_dict_prefix = tuple(remove_dict_prefix)
|
103 |
+
flattend_train_state = {}
|
104 |
+
with mlxu.open_file(path) as fin:
|
105 |
+
# 83886080 bytes = 80 MB, which is 16 blocks on GCS
|
106 |
+
unpacker = msgpack.Unpacker(fin, read_size=83886080, max_buffer_size=0)
|
107 |
+
for key, value in unpacker:
|
108 |
+
key = tuple(key)
|
109 |
+
if remove_dict_prefix is not None:
|
110 |
+
if key[:len(remove_dict_prefix)] == remove_dict_prefix:
|
111 |
+
key = key[len(remove_dict_prefix):]
|
112 |
+
else:
|
113 |
+
continue
|
114 |
+
|
115 |
+
tensor = from_bytes(None, value)
|
116 |
+
if shard_fns is not None:
|
117 |
+
tensor = shard_fns[key](tensor)
|
118 |
+
flattend_train_state[key] = tensor
|
119 |
+
|
120 |
+
if target is not None:
|
121 |
+
flattened_target = flatten_dict(
|
122 |
+
to_state_dict(target), keep_empty_nodes=True
|
123 |
+
)
|
124 |
+
for key, value in flattened_target.items():
|
125 |
+
if key not in flattend_train_state and value == empty_node:
|
126 |
+
flattend_train_state[key] = value
|
127 |
+
|
128 |
+
train_state = unflatten_dict(flattend_train_state)
|
129 |
+
if target is None:
|
130 |
+
return train_state
|
131 |
+
|
132 |
+
return from_state_dict(target, train_state)
|
133 |
+
|
134 |
+
@staticmethod
|
135 |
+
def load_flax_checkpoint(path, target=None, shard_fns=None):
|
136 |
+
""" Load a standard flax checkpoint that's not saved with the
|
137 |
+
msgpack streaming format.
|
138 |
+
"""
|
139 |
+
with mlxu.open_file(path, "rb") as fin:
|
140 |
+
encoded_bytes = fin.read()
|
141 |
+
|
142 |
+
state_dict = flax.serialization.msgpack_restore(encoded_bytes)
|
143 |
+
if shard_fns is not None:
|
144 |
+
shard_fns = to_state_dict(shard_fns)
|
145 |
+
state_dict = tree_apply(shard_fns, state_dict)
|
146 |
+
|
147 |
+
if target is None:
|
148 |
+
return state_dict
|
149 |
+
return from_state_dict(target, state_dict)
|
150 |
+
|
151 |
+
@classmethod
|
152 |
+
def load_trainstate_checkpoint(cls, load_from, trainstate_target=None,
|
153 |
+
trainstate_shard_fns=None,
|
154 |
+
disallow_trainstate=False):
|
155 |
+
if trainstate_target is not None:
|
156 |
+
params_target = trainstate_target.params['params']
|
157 |
+
else:
|
158 |
+
params_target = None
|
159 |
+
|
160 |
+
if trainstate_shard_fns is not None:
|
161 |
+
params_shard_fns = trainstate_shard_fns.params['params']
|
162 |
+
else:
|
163 |
+
params_shard_fns = None
|
164 |
+
|
165 |
+
load_type, load_path = load_from.split('::', 1)
|
166 |
+
if disallow_trainstate:
|
167 |
+
assert load_type != 'trainstate', 'Loading full trainstate is not allowed!'
|
168 |
+
train_state = None
|
169 |
+
restored_params = None
|
170 |
+
if load_type == 'trainstate':
|
171 |
+
# Load the entire train state in the streaming format
|
172 |
+
train_state = cls.load_checkpoint(
|
173 |
+
path=load_path,
|
174 |
+
target=trainstate_target,
|
175 |
+
shard_fns=trainstate_shard_fns,
|
176 |
+
)
|
177 |
+
elif load_type == 'trainstate_params':
|
178 |
+
# Load the params part of the train state in the streaming format
|
179 |
+
restored_params = cls.load_checkpoint(
|
180 |
+
path=load_path,
|
181 |
+
target=params_target,
|
182 |
+
shard_fns=params_shard_fns,
|
183 |
+
remove_dict_prefix=('params', 'params'),
|
184 |
+
)
|
185 |
+
restored_params = flax.core.frozen_dict.freeze(
|
186 |
+
{'params': restored_params}
|
187 |
+
)
|
188 |
+
elif load_type == 'params':
|
189 |
+
# Load the params in the streaming format
|
190 |
+
restored_params = cls.load_checkpoint(
|
191 |
+
path=load_path,
|
192 |
+
target=params_target,
|
193 |
+
shard_fns=params_shard_fns,
|
194 |
+
)
|
195 |
+
restored_params = flax.core.frozen_dict.freeze(
|
196 |
+
{'params': restored_params}
|
197 |
+
)
|
198 |
+
elif load_type == 'flax_params':
|
199 |
+
# Load the params in the standard flax format (non-streaming)
|
200 |
+
# This requires the entire params to fit in memory
|
201 |
+
restored_params = cls.load_flax_checkpoint(
|
202 |
+
path=load_path,
|
203 |
+
target=params_target,
|
204 |
+
shard_fns=params_shard_fns
|
205 |
+
)
|
206 |
+
restored_params = flax.core.frozen_dict.freeze(
|
207 |
+
{'params': restored_params}
|
208 |
+
)
|
209 |
+
else:
|
210 |
+
raise ValueError(f'Invalid load_from type: {load_type}')
|
211 |
+
|
212 |
+
return train_state, restored_params
|
EasyLM/data.py
ADDED
@@ -0,0 +1,431 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
import pprint
|
3 |
+
import time
|
4 |
+
from functools import partial
|
5 |
+
import json
|
6 |
+
import base64
|
7 |
+
from multiprocessing import Pool
|
8 |
+
|
9 |
+
import h5py
|
10 |
+
import mlxu
|
11 |
+
from ml_collections.config_dict import config_dict
|
12 |
+
from ml_collections import ConfigDict
|
13 |
+
from tqdm import tqdm, trange
|
14 |
+
import numpy as np
|
15 |
+
|
16 |
+
from datasets import load_dataset, load_from_disk
|
17 |
+
|
18 |
+
|
19 |
+
class DatasetFactory(object):
|
20 |
+
""" Datset builder class. """
|
21 |
+
|
22 |
+
@staticmethod
|
23 |
+
def get_default_config(updates=None):
|
24 |
+
config = ConfigDict()
|
25 |
+
config.type = 'huggingface'
|
26 |
+
config.text_processor = TextProcessor.get_default_config()
|
27 |
+
config.huggingface_dataset = HuggingfaceDataset.get_default_config()
|
28 |
+
config.json_dataset = JsonDataset.get_default_config()
|
29 |
+
|
30 |
+
if updates is not None:
|
31 |
+
config.update(ConfigDict(updates).copy_and_resolve_references())
|
32 |
+
return config
|
33 |
+
|
34 |
+
@classmethod
|
35 |
+
def load_dataset(cls, config, tokenizer, **kwargs):
|
36 |
+
config = cls.get_default_config(config)
|
37 |
+
text_processor = TextProcessor(config.text_processor, tokenizer)
|
38 |
+
if config.type == 'huggingface':
|
39 |
+
return HuggingfaceDataset(
|
40 |
+
config.huggingface_dataset, tokenizer, text_processor, **kwargs
|
41 |
+
)
|
42 |
+
elif config.type == 'json':
|
43 |
+
return JsonDataset(config.json_dataset, tokenizer, text_processor, **kwargs)
|
44 |
+
else:
|
45 |
+
raise ValueError(f'Unknown dataset type: {config.type}')
|
46 |
+
|
47 |
+
def __init__(self):
|
48 |
+
raise ValueError('DatasetFactory is a static class and should not be instantiated.')
|
49 |
+
|
50 |
+
|
51 |
+
class TextProcessor(object):
|
52 |
+
""" Example processor that converts a dictionary of texts into tokens. """
|
53 |
+
|
54 |
+
@staticmethod
|
55 |
+
def get_default_config(updates=None):
|
56 |
+
config = ConfigDict()
|
57 |
+
config.fields_from_example = ''
|
58 |
+
config.fields = ''
|
59 |
+
config.subfield_separator = ' '
|
60 |
+
config.add_bos_token = True
|
61 |
+
config.add_eos_token = True
|
62 |
+
config.prepend_text = ''
|
63 |
+
config.base64_token_dtype = 'i4'
|
64 |
+
if updates is not None:
|
65 |
+
config.update(ConfigDict(updates).copy_and_resolve_references())
|
66 |
+
return config
|
67 |
+
|
68 |
+
def __init__(self, config, tokenizer):
|
69 |
+
self.config = self.get_default_config(config)
|
70 |
+
assert self.config.fields != '' or self.config.fields_from_example != '', (
|
71 |
+
'Either fields or fields_from_example must be specified.'
|
72 |
+
)
|
73 |
+
self.tokenizer = tokenizer
|
74 |
+
|
75 |
+
def __call__(self, example, has_aux=False):
|
76 |
+
if has_aux:
|
77 |
+
example, *aux = example
|
78 |
+
else:
|
79 |
+
aux = tuple()
|
80 |
+
token_buffer = []
|
81 |
+
loss_mask_buffer = []
|
82 |
+
|
83 |
+
if self.config.add_bos_token:
|
84 |
+
token_buffer.append(self.tokenizer.bos_token_id)
|
85 |
+
loss_mask_buffer.append(0.0)
|
86 |
+
|
87 |
+
if self.config.fields_from_example != '':
|
88 |
+
fields = example[self.config.fields_from_example].split(',')
|
89 |
+
else:
|
90 |
+
fields = self.config.fields.split(',')
|
91 |
+
|
92 |
+
for i, field in enumerate(fields):
|
93 |
+
if field.startswith('[') and field.endswith(']'):
|
94 |
+
# No loss for this field.
|
95 |
+
field = field[1:-1]
|
96 |
+
mask = 0.0
|
97 |
+
else:
|
98 |
+
mask = 1.0
|
99 |
+
|
100 |
+
if field.startswith('<|') and field.endswith('|>'):
|
101 |
+
# Special tokens.
|
102 |
+
field = field[2:-2]
|
103 |
+
if field == 'bos':
|
104 |
+
token_buffer.append(self.tokenizer.bos_token_id)
|
105 |
+
elif field == 'eos':
|
106 |
+
token_buffer.append(self.tokenizer.eos_token_id)
|
107 |
+
else:
|
108 |
+
# Token ID specified directly.
|
109 |
+
token_buffer.append(int(field))
|
110 |
+
loss_mask_buffer.append(mask)
|
111 |
+
elif field.startswith('{') and field.endswith('}'):
|
112 |
+
field = field[1:-1]
|
113 |
+
# Base64 encoded raw tokens.
|
114 |
+
tokens = np.frombuffer(
|
115 |
+
base64.b64decode(example[field]),
|
116 |
+
dtype=self.config.base64_token_dtype
|
117 |
+
).tolist()
|
118 |
+
token_buffer.extend(tokens)
|
119 |
+
loss_mask_buffer.extend([mask for _ in range(len(tokens))])
|
120 |
+
else:
|
121 |
+
subfields = field.split('+')
|
122 |
+
text = self.config.subfield_separator.join(
|
123 |
+
[example[subfield] for subfield in subfields]
|
124 |
+
)
|
125 |
+
if i == 0:
|
126 |
+
text = self.config.prepend_text + text
|
127 |
+
tokens = self.tokenizer.encode(text)
|
128 |
+
token_buffer.extend(tokens)
|
129 |
+
loss_mask_buffer.extend([mask for _ in range(len(tokens))])
|
130 |
+
|
131 |
+
if self.config.add_eos_token:
|
132 |
+
token_buffer.append(self.tokenizer.eos_token_id)
|
133 |
+
loss_mask_buffer.append(1.0)
|
134 |
+
|
135 |
+
return token_buffer, loss_mask_buffer, *aux
|
136 |
+
|
137 |
+
|
138 |
+
class HuggingfaceDataset(object):
|
139 |
+
""" Huggingface dataset, where the dataset is loaded using the huggingface
|
140 |
+
datasets.load_dataset() function.
|
141 |
+
"""
|
142 |
+
|
143 |
+
@staticmethod
|
144 |
+
def get_default_config(updates=None):
|
145 |
+
config = ConfigDict()
|
146 |
+
config.path = 'c4'
|
147 |
+
config.name = 'en'
|
148 |
+
config.split = 'train'
|
149 |
+
config.streaming = False
|
150 |
+
config.seq_length = 1024
|
151 |
+
config.batch_size = 8
|
152 |
+
config.always_start_with_bos = False
|
153 |
+
config.start_seek_loc = 0
|
154 |
+
config.tokens_count_at_start = 0
|
155 |
+
config.batch_token_dtype = 'i4'
|
156 |
+
|
157 |
+
if updates is not None:
|
158 |
+
config.update(ConfigDict(updates).copy_and_resolve_references())
|
159 |
+
return config
|
160 |
+
|
161 |
+
def __init__(self, config, tokenizer, text_processor, eval_dataset=False):
|
162 |
+
self.config = self.get_default_config(config)
|
163 |
+
name = self.config.name if self.config.name != '' else None
|
164 |
+
split = self.config.split if self.config.split != '' else None
|
165 |
+
self._tokenizer = tokenizer
|
166 |
+
self._text_processor = text_processor
|
167 |
+
self._dataset = load_from_disk(
|
168 |
+
self.config.path
|
169 |
+
)[split]
|
170 |
+
self._dataset = self._dataset.to_iterable_dataset(num_shards=128 if len(self._dataset) > 128 else len(self._dataset))
|
171 |
+
self._eval_dataset = eval_dataset
|
172 |
+
self._train_epochs = 0
|
173 |
+
self._dataset_loc = self.config.start_seek_loc
|
174 |
+
self._total_tokens = self.config.tokens_count_at_start
|
175 |
+
self._index = 0
|
176 |
+
|
177 |
+
def __iter__(self):
|
178 |
+
chunk_size = self.config.batch_size * self.config.seq_length
|
179 |
+
total_tokens = 0
|
180 |
+
while True:
|
181 |
+
token_buffer = []
|
182 |
+
loss_mask_buffer = []
|
183 |
+
if not self._eval_dataset:
|
184 |
+
self._shuffle()
|
185 |
+
for index, example in enumerate(self._dataset):
|
186 |
+
self._index = index
|
187 |
+
if not self._eval_dataset and self._dataset_loc > index:
|
188 |
+
continue
|
189 |
+
tokens, loss_masks = self.text_processor(example)
|
190 |
+
token_buffer.extend(tokens)
|
191 |
+
loss_mask_buffer.extend(loss_masks)
|
192 |
+
while len(token_buffer) > chunk_size + 1:
|
193 |
+
self._total_tokens += chunk_size
|
194 |
+
metrics = {
|
195 |
+
'dataset_example_index': index,
|
196 |
+
'dataset_total_tokens': self._total_tokens,
|
197 |
+
'epoch': self._train_epochs,
|
198 |
+
}
|
199 |
+
batch = {
|
200 |
+
'input_tokens': np.array(token_buffer[:chunk_size], dtype=self.config.batch_token_dtype).reshape(
|
201 |
+
self.config.batch_size, -1
|
202 |
+
),
|
203 |
+
'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=self.config.batch_token_dtype).reshape(
|
204 |
+
self.config.batch_size, -1
|
205 |
+
),
|
206 |
+
'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(
|
207 |
+
self.config.batch_size, -1
|
208 |
+
),
|
209 |
+
}
|
210 |
+
if self.config.always_start_with_bos:
|
211 |
+
batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id
|
212 |
+
yield batch, metrics
|
213 |
+
token_buffer = token_buffer[chunk_size:]
|
214 |
+
loss_mask_buffer = loss_mask_buffer[chunk_size:]
|
215 |
+
|
216 |
+
if self._eval_dataset:
|
217 |
+
break
|
218 |
+
else:
|
219 |
+
self._dataset_loc = 0
|
220 |
+
self._shuffle()
|
221 |
+
self._train_epochs += 1
|
222 |
+
print(f"TRAIN {self._train_epochs} EPOCH DONE")
|
223 |
+
|
224 |
+
def _shuffle(self):
|
225 |
+
self._dataset = self._dataset.shuffle(buffer_size=100)
|
226 |
+
|
227 |
+
def get_state_dict(self):
|
228 |
+
return dict(
|
229 |
+
config=self.config,
|
230 |
+
dataset_loc=self._index,
|
231 |
+
total_tokens=self._total_tokens,
|
232 |
+
epochs=self._train_epochs,
|
233 |
+
)
|
234 |
+
|
235 |
+
def load_state_dict(self, state_dict):
|
236 |
+
if 'config' in state_dict:
|
237 |
+
self.config.update(ConfigDict(state_dict['config']))
|
238 |
+
self._dataset_loc = state_dict.get('dataset_loc', self.config.start_seek_loc)
|
239 |
+
self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)
|
240 |
+
self._train_epochs = state_dict.get('epochs', 0)
|
241 |
+
|
242 |
+
@property
|
243 |
+
def seq_length(self):
|
244 |
+
return self.config.seq_length
|
245 |
+
|
246 |
+
@property
|
247 |
+
def tokenizer(self):
|
248 |
+
return self._tokenizer
|
249 |
+
|
250 |
+
@property
|
251 |
+
def text_processor(self):
|
252 |
+
return self._text_processor
|
253 |
+
|
254 |
+
@property
|
255 |
+
def dataset(self):
|
256 |
+
return self._dataset
|
257 |
+
|
258 |
+
@property
|
259 |
+
def vocab_size(self):
|
260 |
+
return len(self._tokenizer)
|
261 |
+
|
262 |
+
|
263 |
+
class JsonDataset(object):
|
264 |
+
""" JSON dataset, where each line of the data file contains a JSON
|
265 |
+
dictionary with text fields.
|
266 |
+
"""
|
267 |
+
|
268 |
+
@staticmethod
|
269 |
+
def get_default_config(updates=None):
|
270 |
+
config = ConfigDict()
|
271 |
+
config.path = ''
|
272 |
+
config.seq_length = 1024
|
273 |
+
config.batch_size = 8
|
274 |
+
config.always_start_with_bos = False
|
275 |
+
config.start_seek_loc = 0
|
276 |
+
config.example_index_at_start = 0
|
277 |
+
config.tokens_count_at_start = 0
|
278 |
+
config.tokenizer_processes = 1
|
279 |
+
config.tokenizer_parallel_chunk_size = 32
|
280 |
+
config.tokenizer_parallel_batch_size = 1024
|
281 |
+
config.throughput_average_window_size = 200
|
282 |
+
|
283 |
+
if updates is not None:
|
284 |
+
config.update(ConfigDict(updates).copy_and_resolve_references())
|
285 |
+
return config
|
286 |
+
|
287 |
+
def __init__(self, config, tokenizer, text_processor):
|
288 |
+
self.config = self.get_default_config(config)
|
289 |
+
assert self.config.path != ''
|
290 |
+
self._tokenizer = tokenizer
|
291 |
+
self._text_processor = text_processor
|
292 |
+
self._index = self.config.example_index_at_start
|
293 |
+
self._file_loc = self.config.start_seek_loc
|
294 |
+
self._total_tokens = self.config.tokens_count_at_start
|
295 |
+
|
296 |
+
def parse_json(self, line):
|
297 |
+
if not line or line == '\n':
|
298 |
+
return None
|
299 |
+
try:
|
300 |
+
data = json.loads(line)
|
301 |
+
except json.decoder.JSONDecodeError:
|
302 |
+
print(f'Error parsing json line:\n{line}')
|
303 |
+
return None
|
304 |
+
return data
|
305 |
+
|
306 |
+
def json_iterator(self):
|
307 |
+
with mlxu.open_file(self.config.path, 'r') as fin:
|
308 |
+
fin.seek(self._file_loc)
|
309 |
+
while True:
|
310 |
+
line = fin.readline()
|
311 |
+
self._file_loc = fin.tell()
|
312 |
+
if not line: # Reached EOF
|
313 |
+
self._index = 0
|
314 |
+
fin.seek(0)
|
315 |
+
continue
|
316 |
+
|
317 |
+
data = self.parse_json(line)
|
318 |
+
if data is not None:
|
319 |
+
# JSON parsing succeeded
|
320 |
+
yield data, self._file_loc, self._index
|
321 |
+
self._index += 1
|
322 |
+
|
323 |
+
def batched(self, iterator, batch_size):
|
324 |
+
batch = []
|
325 |
+
for example in iterator:
|
326 |
+
batch.append(example)
|
327 |
+
if len(batch) == batch_size:
|
328 |
+
yield batch
|
329 |
+
batch = []
|
330 |
+
if len(batch) > 0:
|
331 |
+
yield batch
|
332 |
+
|
333 |
+
def parallel_example_iterator(self):
|
334 |
+
if self.config.tokenizer_processes == 1:
|
335 |
+
for example, loc, index in self.json_iterator():
|
336 |
+
yield self.text_processor((example, loc, index), has_aux=True)
|
337 |
+
else:
|
338 |
+
process_pool = Pool(self.config.tokenizer_processes)
|
339 |
+
batched_iterator = self.batched(
|
340 |
+
self.json_iterator(), self.config.tokenizer_parallel_batch_size
|
341 |
+
)
|
342 |
+
with process_pool as pool:
|
343 |
+
map_fn = partial(self.text_processor, has_aux=True)
|
344 |
+
next_batch = pool.map_async(
|
345 |
+
map_fn, next(batched_iterator),
|
346 |
+
chunksize=self.config.tokenizer_parallel_chunk_size
|
347 |
+
)
|
348 |
+
while True:
|
349 |
+
current_batch = next_batch
|
350 |
+
next_batch = pool.map_async(
|
351 |
+
map_fn, next(batched_iterator),
|
352 |
+
chunksize=self.config.tokenizer_parallel_chunk_size
|
353 |
+
)
|
354 |
+
for example in current_batch.get():
|
355 |
+
yield example
|
356 |
+
|
357 |
+
def __iter__(self):
|
358 |
+
chunk_size = self.config.batch_size * self.config.seq_length
|
359 |
+
token_buffer = []
|
360 |
+
loss_mask_buffer = []
|
361 |
+
last_time = 0.0
|
362 |
+
step_times = []
|
363 |
+
start_time = time.time()
|
364 |
+
start_tokens = self._total_tokens
|
365 |
+
for tokens, loss_masks, loc, index in self.parallel_example_iterator():
|
366 |
+
token_buffer.extend(tokens)
|
367 |
+
loss_mask_buffer.extend(loss_masks)
|
368 |
+
while len(token_buffer) > chunk_size + 1:
|
369 |
+
self._total_tokens += chunk_size
|
370 |
+
step_times.append(time.time() - last_time)
|
371 |
+
last_time = time.time()
|
372 |
+
if len(step_times) > self.config.throughput_average_window_size:
|
373 |
+
step_times = step_times[-self.config.throughput_average_window_size:]
|
374 |
+
average_throughput = chunk_size / np.mean(step_times)
|
375 |
+
accumulated_throughput = (
|
376 |
+
(self._total_tokens - start_tokens) / (time.time() - start_time)
|
377 |
+
)
|
378 |
+
metrics = {
|
379 |
+
'dataset_file_loc': loc,
|
380 |
+
'dataset_example_index': index,
|
381 |
+
'dataset_total_tokens': self._total_tokens,
|
382 |
+
'dataset_accumulated_tps': accumulated_throughput,
|
383 |
+
'dataset_average_tps': average_throughput,
|
384 |
+
}
|
385 |
+
batch = {
|
386 |
+
'input_tokens': np.array(token_buffer[:chunk_size], dtype=np.int32).reshape(
|
387 |
+
self.config.batch_size, -1
|
388 |
+
),
|
389 |
+
'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=np.int32).reshape(
|
390 |
+
self.config.batch_size, -1
|
391 |
+
),
|
392 |
+
'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(
|
393 |
+
self.config.batch_size, -1
|
394 |
+
),
|
395 |
+
}
|
396 |
+
if self.config.always_start_with_bos:
|
397 |
+
batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id
|
398 |
+
yield batch, metrics
|
399 |
+
token_buffer = token_buffer[chunk_size:]
|
400 |
+
loss_mask_buffer = loss_mask_buffer[chunk_size:]
|
401 |
+
|
402 |
+
def get_state_dict(self):
|
403 |
+
return dict(
|
404 |
+
config=self.config,
|
405 |
+
index=self._index,
|
406 |
+
file_loc=self._file_loc,
|
407 |
+
total_tokens=self._total_tokens,
|
408 |
+
)
|
409 |
+
|
410 |
+
def load_state_dict(self, state_dict):
|
411 |
+
if 'config' in state_dict:
|
412 |
+
self.config.update(ConfigDict(state_dict['config']))
|
413 |
+
self._index = state_dict.get('index', self.config.example_index_at_start)
|
414 |
+
self._file_loc = state_dict.get('file_loc', self.config.start_seek_loc)
|
415 |
+
self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)
|
416 |
+
|
417 |
+
@property
|
418 |
+
def seq_length(self):
|
419 |
+
return self.config.seq_length
|
420 |
+
|
421 |
+
@property
|
422 |
+
def tokenizer(self):
|
423 |
+
return self._tokenizer
|
424 |
+
|
425 |
+
@property
|
426 |
+
def text_processor(self):
|
427 |
+
return self._text_processor
|
428 |
+
|
429 |
+
@property
|
430 |
+
def vocab_size(self):
|
431 |
+
return len(self.tokenizer)
|
EasyLM/jax_utils.py
ADDED
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import math
|
3 |
+
from typing import Any, Mapping, Text, Tuple, Union, NamedTuple
|
4 |
+
from functools import partial
|
5 |
+
import re
|
6 |
+
import dataclasses
|
7 |
+
import random
|
8 |
+
from ml_collections import ConfigDict
|
9 |
+
from ml_collections.config_dict.config_dict import placeholder
|
10 |
+
|
11 |
+
import flax
|
12 |
+
import jax
|
13 |
+
import jax.numpy as jnp
|
14 |
+
from jax.sharding import PartitionSpec as PS
|
15 |
+
from jax.sharding import Mesh
|
16 |
+
from jax.experimental import mesh_utils
|
17 |
+
from jax.experimental.pjit import with_sharding_constraint as _with_sharding_constraint
|
18 |
+
from jax.experimental.pjit import pjit
|
19 |
+
from jax.interpreters import pxla
|
20 |
+
import numpy as np
|
21 |
+
from transformers import FlaxLogitsWarper
|
22 |
+
|
23 |
+
|
24 |
+
class JaxRNG(object):
|
25 |
+
""" A convenient stateful Jax RNG wrapper. Can be used to wrap RNG inside
|
26 |
+
pure function.
|
27 |
+
"""
|
28 |
+
|
29 |
+
@classmethod
|
30 |
+
def from_seed(cls, seed):
|
31 |
+
return cls(jax.random.PRNGKey(seed))
|
32 |
+
|
33 |
+
def __init__(self, rng):
|
34 |
+
self.rng = rng
|
35 |
+
|
36 |
+
def __call__(self, keys=None):
|
37 |
+
if keys is None:
|
38 |
+
self.rng, split_rng = jax.random.split(self.rng)
|
39 |
+
return split_rng
|
40 |
+
elif isinstance(keys, int):
|
41 |
+
split_rngs = jax.random.split(self.rng, num=keys + 1)
|
42 |
+
self.rng = split_rngs[0]
|
43 |
+
return tuple(split_rngs[1:])
|
44 |
+
else:
|
45 |
+
split_rngs = jax.random.split(self.rng, num=len(keys) + 1)
|
46 |
+
self.rng = split_rngs[0]
|
47 |
+
return {key: val for key, val in zip(keys, split_rngs[1:])}
|
48 |
+
|
49 |
+
|
50 |
+
class JaxDistributedConfig(object):
|
51 |
+
""" Utility class for initializing JAX distributed. """
|
52 |
+
|
53 |
+
@staticmethod
|
54 |
+
def get_default_config(updates=None):
|
55 |
+
config = ConfigDict()
|
56 |
+
config.initialize_jax_distributed = False
|
57 |
+
config.coordinator_address = placeholder(str)
|
58 |
+
config.num_processes = placeholder(int)
|
59 |
+
config.process_id = placeholder(int)
|
60 |
+
config.local_device_ids = placeholder(str)
|
61 |
+
|
62 |
+
if updates is not None:
|
63 |
+
config.update(ConfigDict(updates).copy_and_resolve_references())
|
64 |
+
return config
|
65 |
+
|
66 |
+
@classmethod
|
67 |
+
def initialize(cls, config):
|
68 |
+
config = cls.get_default_config(config)
|
69 |
+
if config.initialize_jax_distributed:
|
70 |
+
if config.local_device_ids is not None:
|
71 |
+
local_device_ids = [int(x) for x in config.local_device_ids.split(',')]
|
72 |
+
else:
|
73 |
+
local_device_ids = None
|
74 |
+
|
75 |
+
jax.distributed.initialize(
|
76 |
+
coordinator_address=config.coordinator_address,
|
77 |
+
num_processes=config.num_processes,
|
78 |
+
process_id=config.process_id,
|
79 |
+
local_device_ids=local_device_ids,
|
80 |
+
)
|
81 |
+
|
82 |
+
|
83 |
+
class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):
|
84 |
+
""" JIT traceable version of FlaxLogitsWarper that performs temperature scaling."""
|
85 |
+
def __init__(self, temperature):
|
86 |
+
self.temperature = temperature
|
87 |
+
|
88 |
+
def __call__(self, input_ids, scores, cur_len):
|
89 |
+
return scores / jnp.clip(self.temperature, a_min=1e-8)
|
90 |
+
|
91 |
+
|
92 |
+
def make_shard_and_gather_fns(partition_specs, dtype_specs=None):
|
93 |
+
""" Create pytree of sharding and gathering functions from pytree of
|
94 |
+
partition specs.
|
95 |
+
"""
|
96 |
+
float_dtypes = (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64)
|
97 |
+
|
98 |
+
def make_to_dtype_fn(dtype_spec):
|
99 |
+
def to_dtype(tensor):
|
100 |
+
if dtype_specs in float_dtypes and getattr(tensor, 'dtype', None) in float_dtypes:
|
101 |
+
# Convert all float tensors to the same dtype
|
102 |
+
return tensor.astype(dtype_specs)
|
103 |
+
elif hasattr(dtype_spec, 'dtype') and hasattr(tensor, 'dtype'):
|
104 |
+
return tensor.astype(dtype_spec.dtype)
|
105 |
+
return tensor
|
106 |
+
return to_dtype
|
107 |
+
|
108 |
+
def make_shard_fn(partition_spec, dtype_spec=None):
|
109 |
+
jax_shard_function = pjit(
|
110 |
+
make_to_dtype_fn(dtype_spec),
|
111 |
+
in_shardings=None,
|
112 |
+
out_shardings=partition_spec
|
113 |
+
)
|
114 |
+
def shard_fn(tensor):
|
115 |
+
return jax_shard_function(tensor).block_until_ready()
|
116 |
+
return shard_fn
|
117 |
+
|
118 |
+
def make_gather_fn(partition_spec, dtype_spec=None):
|
119 |
+
jax_gather_fn = pjit(
|
120 |
+
make_to_dtype_fn(dtype_spec),
|
121 |
+
in_shardings=partition_spec,
|
122 |
+
out_shardings=None
|
123 |
+
)
|
124 |
+
def gather_fn(tensor):
|
125 |
+
return jax.device_get(jax_gather_fn(tensor))
|
126 |
+
return gather_fn
|
127 |
+
|
128 |
+
if dtype_specs is None or dtype_specs in float_dtypes:
|
129 |
+
shard_fns = jax.tree_util.tree_map(make_shard_fn, partition_specs)
|
130 |
+
gather_fns = jax.tree_util.tree_map(make_gather_fn, partition_specs)
|
131 |
+
else:
|
132 |
+
shard_fns = jax.tree_util.tree_map(
|
133 |
+
make_shard_fn, partition_specs, dtype_specs
|
134 |
+
)
|
135 |
+
gather_fns = jax.tree_util.tree_map(
|
136 |
+
make_gather_fn, partition_specs, dtype_specs
|
137 |
+
)
|
138 |
+
return shard_fns, gather_fns
|
139 |
+
|
140 |
+
|
141 |
+
def set_random_seed(seed):
|
142 |
+
np.random.seed(seed)
|
143 |
+
random.seed(seed)
|
144 |
+
init_rng(seed)
|
145 |
+
|
146 |
+
|
147 |
+
def get_jax_mesh(axis_dims, names):
|
148 |
+
if axis_dims.startswith('!'):
|
149 |
+
# Allow splitting a physical mesh axis if needed
|
150 |
+
mesh_axis_splitting = True
|
151 |
+
axis_dims = axis_dims[1:]
|
152 |
+
else:
|
153 |
+
mesh_axis_splitting = False
|
154 |
+
|
155 |
+
if ':' in axis_dims:
|
156 |
+
dims = []
|
157 |
+
dim_names = []
|
158 |
+
for axis in axis_dims.split(','):
|
159 |
+
name, dim = axis.split(':')
|
160 |
+
assert name in names
|
161 |
+
dims.append(int(dim))
|
162 |
+
dim_names.append(name)
|
163 |
+
assert(set(dim_names) == set(names))
|
164 |
+
else:
|
165 |
+
dims = [int(x) for x in axis_dims.split(',')]
|
166 |
+
dim_names = names
|
167 |
+
assert len(dims) == len(names)
|
168 |
+
mesh_shape = np.arange(jax.device_count()).reshape(dims).shape
|
169 |
+
if mesh_axis_splitting:
|
170 |
+
physical_mesh = np.array(jax.devices()).reshape(mesh_shape)
|
171 |
+
else:
|
172 |
+
physical_mesh = mesh_utils.create_device_mesh(mesh_shape)
|
173 |
+
return Mesh(physical_mesh, dim_names)
|
174 |
+
|
175 |
+
|
176 |
+
def names_in_current_mesh(*names):
|
177 |
+
""" Check if current mesh axes contain these names. """
|
178 |
+
mesh_axis_names = pxla.thread_resources.env.physical_mesh.axis_names
|
179 |
+
return set(names) <= set(mesh_axis_names)
|
180 |
+
|
181 |
+
|
182 |
+
def get_names_from_parition_spec(partition_specs):
|
183 |
+
""" Return axis names from partition specs. """
|
184 |
+
names = set()
|
185 |
+
if isinstance(partition_specs, dict):
|
186 |
+
partition_specs = partition_specs.values()
|
187 |
+
for item in partition_specs:
|
188 |
+
if item is None:
|
189 |
+
continue
|
190 |
+
elif isinstance(item, str):
|
191 |
+
names.add(item)
|
192 |
+
else:
|
193 |
+
names.update(get_names_from_parition_spec(item))
|
194 |
+
|
195 |
+
return list(names)
|
196 |
+
|
197 |
+
|
198 |
+
def with_sharding_constraint(x, partition_specs):
|
199 |
+
""" A smarter version of with_sharding_constraint that only applies the
|
200 |
+
constraint if the current mesh contains the axes in the partition specs.
|
201 |
+
"""
|
202 |
+
axis_names = get_names_from_parition_spec(partition_specs)
|
203 |
+
if names_in_current_mesh(*axis_names):
|
204 |
+
x = _with_sharding_constraint(x, partition_specs)
|
205 |
+
return x
|
206 |
+
|
207 |
+
|
208 |
+
def wrap_function_with_rng(rng):
|
209 |
+
""" To be used as decorator, automatically bookkeep a RNG for the wrapped function. """
|
210 |
+
def wrap_function(function):
|
211 |
+
def wrapped(*args, **kwargs):
|
212 |
+
nonlocal rng
|
213 |
+
rng, split_rng = jax.random.split(rng)
|
214 |
+
return function(split_rng, *args, **kwargs)
|
215 |
+
return wrapped
|
216 |
+
return wrap_function
|
217 |
+
|
218 |
+
|
219 |
+
def init_rng(seed):
|
220 |
+
global jax_utils_rng
|
221 |
+
jax_utils_rng = JaxRNG.from_seed(seed)
|
222 |
+
|
223 |
+
|
224 |
+
def next_rng(*args, **kwargs):
|
225 |
+
global jax_utils_rng
|
226 |
+
return jax_utils_rng(*args, **kwargs)
|
227 |
+
|
228 |
+
|
229 |
+
def get_metrics(metrics, unreplicate=False, stack=False):
|
230 |
+
if unreplicate:
|
231 |
+
metrics = flax.jax_utils.unreplicate(metrics)
|
232 |
+
metrics = jax.device_get(metrics)
|
233 |
+
if stack:
|
234 |
+
return jax.tree_map(lambda *args: np.stack(args), *metrics)
|
235 |
+
else:
|
236 |
+
return {key: float(val) for key, val in metrics.items()}
|
237 |
+
|
238 |
+
|
239 |
+
def mse_loss(val, target, valid=None):
|
240 |
+
if valid is None:
|
241 |
+
valid = jnp.ones((*target.shape[:2], 1))
|
242 |
+
valid = valid.astype(jnp.float32)
|
243 |
+
loss = jnp.mean(
|
244 |
+
jnp.where(
|
245 |
+
valid > 0.0,
|
246 |
+
jnp.square(val - target),
|
247 |
+
0.0
|
248 |
+
)
|
249 |
+
)
|
250 |
+
return loss
|
251 |
+
|
252 |
+
|
253 |
+
def cross_entropy_loss_and_accuracy(logits, tokens, valid=None):
|
254 |
+
if valid is None:
|
255 |
+
valid = jnp.ones(tokens.shape[:2])
|
256 |
+
valid = valid.astype(jnp.float32)
|
257 |
+
valid_text_length = jnp.maximum(jnp.sum(valid, axis=-1), 1e-10)
|
258 |
+
logits = logits.astype(jnp.float32) # for numerical stability
|
259 |
+
token_log_prob = jnp.squeeze(
|
260 |
+
jnp.take_along_axis(
|
261 |
+
jax.nn.log_softmax(logits, axis=-1),
|
262 |
+
jnp.expand_dims(tokens, -1),
|
263 |
+
axis=-1,
|
264 |
+
),
|
265 |
+
-1,
|
266 |
+
)
|
267 |
+
token_log_prob = jnp.where(valid > 0.0, token_log_prob, jnp.array(0.0))
|
268 |
+
loss = -jnp.mean(jnp.sum(token_log_prob, axis=-1) / valid_text_length)
|
269 |
+
correct = jnp.where(
|
270 |
+
valid > 0.0,
|
271 |
+
jnp.argmax(logits, axis=-1) == tokens,
|
272 |
+
jnp.array(False)
|
273 |
+
)
|
274 |
+
accuracy = jnp.mean(jnp.sum(correct, axis=-1) / valid_text_length)
|
275 |
+
return loss, accuracy
|
276 |
+
|
277 |
+
|
278 |
+
def global_norm(tree):
|
279 |
+
""" Return the global L2 norm of a pytree. """
|
280 |
+
squared = jax.tree_util.tree_map(lambda x: jnp.sum(jnp.square(x)), tree)
|
281 |
+
flattened, _ = jax.flatten_util.ravel_pytree(squared)
|
282 |
+
return jnp.sqrt(jnp.sum(flattened))
|
283 |
+
|
284 |
+
|
285 |
+
def average_metrics(metrics):
|
286 |
+
with jax.spmd_mode("allow_all"):
|
287 |
+
return jax.tree_map(
|
288 |
+
lambda *args: jnp.mean(jnp.stack(args)),
|
289 |
+
*metrics
|
290 |
+
)
|
291 |
+
|
292 |
+
|
293 |
+
def get_float_dtype_by_name(dtype):
|
294 |
+
return {
|
295 |
+
'bf16': jnp.bfloat16,
|
296 |
+
'bfloat16': jnp.bfloat16,
|
297 |
+
'fp16': jnp.float16,
|
298 |
+
'float16': jnp.float16,
|
299 |
+
'fp32': jnp.float32,
|
300 |
+
'float32': jnp.float32,
|
301 |
+
'fp64': jnp.float64,
|
302 |
+
'float64': jnp.float64,
|
303 |
+
}[dtype]
|
304 |
+
|
305 |
+
|
306 |
+
def float_tensor_to_dtype(tensor, dtype):
|
307 |
+
if dtype is None or dtype == '':
|
308 |
+
return tensor
|
309 |
+
if isinstance(dtype, str):
|
310 |
+
dtype = get_float_dtype_by_name(dtype)
|
311 |
+
float_dtypes = (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64)
|
312 |
+
if getattr(tensor, 'dtype', None) in float_dtypes:
|
313 |
+
tensor = tensor.astype(dtype)
|
314 |
+
return tensor
|
315 |
+
|
316 |
+
|
317 |
+
def float_to_dtype(tree, dtype):
|
318 |
+
return jax.tree_util.tree_map(
|
319 |
+
partial(float_tensor_to_dtype, dtype=dtype), tree
|
320 |
+
)
|
321 |
+
|
322 |
+
|
323 |
+
def get_gradient_checkpoint_policy(name):
|
324 |
+
return {
|
325 |
+
'everything_saveable': jax.checkpoint_policies.everything_saveable,
|
326 |
+
'nothing_saveable': jax.checkpoint_policies.nothing_saveable,
|
327 |
+
'checkpoint_dots': jax.checkpoint_policies.checkpoint_dots,
|
328 |
+
'checkpoint_dots_with_no_batch_dims': jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims,
|
329 |
+
}[name]
|
330 |
+
|
331 |
+
|
332 |
+
def tree_path_to_string(path, sep=None):
|
333 |
+
keys = []
|
334 |
+
for key in path:
|
335 |
+
if isinstance(key, jax.tree_util.SequenceKey):
|
336 |
+
keys.append(str(key.idx))
|
337 |
+
elif isinstance(key, jax.tree_util.DictKey):
|
338 |
+
keys.append(str(key.key))
|
339 |
+
elif isinstance(key, jax.tree_util.GetAttrKey):
|
340 |
+
keys.append(str(key.name))
|
341 |
+
elif isinstance(key, jax.tree_util.FlattenedIndexKey):
|
342 |
+
keys.append(str(key.key))
|
343 |
+
else:
|
344 |
+
keys.append(str(key))
|
345 |
+
if sep is None:
|
346 |
+
return tuple(keys)
|
347 |
+
return sep.join(keys)
|
348 |
+
|
349 |
+
|
350 |
+
def flatten_tree(xs, is_leaf=None, sep=None):
|
351 |
+
flattened, _ = jax.tree_util.tree_flatten_with_path(xs, is_leaf=is_leaf)
|
352 |
+
output = {}
|
353 |
+
for key, val in flattened:
|
354 |
+
output[tree_path_to_string(key, sep=sep)] = val
|
355 |
+
return output
|
356 |
+
|
357 |
+
|
358 |
+
def named_tree_map(f, tree, *rest, is_leaf=None, sep=None):
|
359 |
+
""" An extended version of jax.tree_util.tree_map, where the mapped function
|
360 |
+
f takes both the name (path) and the tree leaf as input.
|
361 |
+
"""
|
362 |
+
return jax.tree_util.tree_map_with_path(
|
363 |
+
lambda path, x, *r: f(tree_path_to_string(path, sep=sep), x, *r),
|
364 |
+
tree, *rest,
|
365 |
+
is_leaf=is_leaf
|
366 |
+
)
|
367 |
+
|
368 |
+
|
369 |
+
def match_partition_rules(rules, params):
|
370 |
+
""" Returns a pytree of PartitionSpec according to rules. Supports handling
|
371 |
+
Flax TrainState and Optax optimizer state.
|
372 |
+
"""
|
373 |
+
def get_partition_spec(name, leaf):
|
374 |
+
if len(leaf.shape) == 0 or np.prod(leaf.shape) == 1:
|
375 |
+
""" Don't partition scalar values. """
|
376 |
+
return PS()
|
377 |
+
for rule, ps in rules:
|
378 |
+
if re.search(rule, name) is not None:
|
379 |
+
return ps
|
380 |
+
raise ValueError(f'Partition rule not found for param: {name}')
|
381 |
+
return named_tree_map(get_partition_spec, params, sep='/')
|
382 |
+
|
383 |
+
|
384 |
+
def get_weight_decay_mask(exclusions):
|
385 |
+
""" Return a weight decay mask function that computes the pytree masks
|
386 |
+
according to the given exclusion rules.
|
387 |
+
"""
|
388 |
+
def decay(name, _):
|
389 |
+
for rule in exclusions:
|
390 |
+
if re.search(rule, name) is not None:
|
391 |
+
return False
|
392 |
+
return True
|
393 |
+
|
394 |
+
def weight_decay_mask(params):
|
395 |
+
return named_tree_map(decay, params, sep='/')
|
396 |
+
|
397 |
+
return weight_decay_mask
|
398 |
+
|
399 |
+
|
400 |
+
def tree_apply(fns, tree):
|
401 |
+
""" Apply a pytree of functions to the pytree. """
|
402 |
+
return jax.tree_util.tree_map(lambda fn, x: fn(x), fns, tree)
|
403 |
+
|
EasyLM/models/__init__.py
ADDED
File without changes
|
EasyLM/models/gptj/__init__.py
ADDED
File without changes
|
EasyLM/models/gptj/gptj_model.py
ADDED
@@ -0,0 +1,1054 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2021 The EleutherAI and The HuggingFace Inc. team.
|
3 |
+
# Modifications copyright 2022 Xinyang Geng
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
|
18 |
+
from functools import partial
|
19 |
+
from typing import Optional, Tuple
|
20 |
+
import json
|
21 |
+
|
22 |
+
import numpy as np
|
23 |
+
|
24 |
+
import flax.linen as nn
|
25 |
+
import jax
|
26 |
+
import jax.numpy as jnp
|
27 |
+
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
28 |
+
from flax.linen import combine_masks, make_causal_mask
|
29 |
+
from flax.linen.attention import dot_product_attention_weights
|
30 |
+
from flax.traverse_util import flatten_dict, unflatten_dict
|
31 |
+
from jax import lax
|
32 |
+
from flax.linen import partitioning as nn_partitioning
|
33 |
+
|
34 |
+
from transformers.configuration_utils import PretrainedConfig
|
35 |
+
from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
|
36 |
+
from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
|
37 |
+
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
38 |
+
from transformers.generation.flax_logits_process import FlaxLogitsProcessorList
|
39 |
+
from transformers import AutoTokenizer
|
40 |
+
from jax.sharding import PartitionSpec
|
41 |
+
|
42 |
+
from ml_collections import ConfigDict
|
43 |
+
from ml_collections.config_dict import config_dict
|
44 |
+
from mlxu import function_args_to_config, load_pickle, open_file
|
45 |
+
|
46 |
+
from EasyLM.jax_utils import (
|
47 |
+
with_sharding_constraint, get_jax_mesh, get_gradient_checkpoint_policy
|
48 |
+
)
|
49 |
+
|
50 |
+
|
51 |
+
"""
|
52 |
+
The follow code is taken from
|
53 |
+
transformers/src/transformers/models/gptj/configuration_gptj.py
|
54 |
+
and modified to work with EasyLM.
|
55 |
+
"""
|
56 |
+
|
57 |
+
|
58 |
+
GPTJ_STANDARD_CONFIGS = {
|
59 |
+
'6b': {
|
60 |
+
"vocab_size": 50400,
|
61 |
+
"n_positions": 2048,
|
62 |
+
"n_embd": 4096,
|
63 |
+
"n_layer": 28,
|
64 |
+
"n_head": 16,
|
65 |
+
"rotary_dim": 64,
|
66 |
+
"n_inner": None,
|
67 |
+
"activation_function": "gelu_new",
|
68 |
+
"layer_norm_epsilon": 1e-5,
|
69 |
+
"initializer_range": 0.02,
|
70 |
+
"scale_attn_weights": True,
|
71 |
+
"use_cache": True,
|
72 |
+
"bos_token_id": 50256,
|
73 |
+
"eos_token_id": 50256,
|
74 |
+
"tie_word_embeddings": False,
|
75 |
+
"n_real_tokens": 50257,
|
76 |
+
}
|
77 |
+
}
|
78 |
+
|
79 |
+
|
80 |
+
class GPTJConfig(PretrainedConfig):
|
81 |
+
r"""
|
82 |
+
This is the configuration class to store the configuration of a [`GPTJModel`]. It is used to instantiate a GPT-J
|
83 |
+
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
84 |
+
defaults will yield a similar configuration to that of the GPT-J
|
85 |
+
[EleutherAI/gpt-j-6B](https://huggingface.co/EleutherAI/gpt-j-6B) architecture. Configuration objects inherit from
|
86 |
+
[`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`]
|
87 |
+
for more information.
|
88 |
+
Args:
|
89 |
+
vocab_size (`int`, *optional*, defaults to 50400):
|
90 |
+
Vocabulary size of the GPT-J model. Defines the number of different tokens that can be represented by the
|
91 |
+
`inputs_ids` passed when calling [`GPTJModel`].
|
92 |
+
n_positions (`int`, *optional*, defaults to 2048):
|
93 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
94 |
+
just in case (e.g., 512 or 1024 or 2048).
|
95 |
+
n_embd (`int`, *optional*, defaults to 4096):
|
96 |
+
Dimensionality of the embeddings and hidden states.
|
97 |
+
n_layer (`int`, *optional*, defaults to 28):
|
98 |
+
Number of hidden layers in the Transformer encoder.
|
99 |
+
n_head (`int`, *optional*, defaults to 16):
|
100 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
101 |
+
rotary_dim (`int`, *optional*, defaults to 64):
|
102 |
+
Number of dimensions in the embedding that Rotary Position Embedding is applied to.
|
103 |
+
n_inner (`int`, *optional*, defaults to 0):
|
104 |
+
Dimensionality of the inner feed-forward layers. 0 will set it to 4 times n_embd
|
105 |
+
activation_function (`str`, *optional*, defaults to `"gelu_new"`):
|
106 |
+
Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
|
107 |
+
resid_pdrop (`float`, *optional*, defaults to 0.1):
|
108 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
109 |
+
embd_pdrop (`int`, *optional*, defaults to 0.1):
|
110 |
+
The dropout ratio for the embeddings.
|
111 |
+
attn_pdrop (`float`, *optional*, defaults to 0.1):
|
112 |
+
The dropout ratio for the attention.
|
113 |
+
layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
|
114 |
+
The epsilon to use in the layer normalization layers.
|
115 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
116 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
117 |
+
scale_attn_weights (`bool`, *optional*, defaults to `True`):
|
118 |
+
Scale attention weights by dividing by sqrt(hidden_size).
|
119 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
120 |
+
Whether or not the model should return the last key/values attentions (not used by all models).
|
121 |
+
Example:
|
122 |
+
```python
|
123 |
+
>>> from transformers import GPTJModel, GPTJConfig
|
124 |
+
>>> # Initializing a GPT-J 6B configuration
|
125 |
+
>>> configuration = GPTJConfig()
|
126 |
+
>>> # Initializing a model from the configuration
|
127 |
+
>>> model = GPTJModel(configuration)
|
128 |
+
>>> # Accessing the model configuration
|
129 |
+
>>> configuration = model.config
|
130 |
+
```"""
|
131 |
+
model_type = "gptj"
|
132 |
+
attribute_map = {
|
133 |
+
"max_position_embeddings": "n_positions",
|
134 |
+
"hidden_size": "n_embd",
|
135 |
+
"num_attention_heads": "n_head",
|
136 |
+
"num_hidden_layers": "n_layer",
|
137 |
+
}
|
138 |
+
|
139 |
+
def __init__(
|
140 |
+
self,
|
141 |
+
vocab_size=50400,
|
142 |
+
n_positions=2048,
|
143 |
+
n_embd=4096,
|
144 |
+
n_layer=28,
|
145 |
+
n_head=16,
|
146 |
+
rotary_dim=64,
|
147 |
+
n_inner=None,
|
148 |
+
activation_function="gelu_new",
|
149 |
+
resid_pdrop=0.0,
|
150 |
+
embd_pdrop=0.0,
|
151 |
+
attn_pdrop=0.0,
|
152 |
+
layer_norm_epsilon=1e-5,
|
153 |
+
initializer_range=0.02,
|
154 |
+
scale_attn_weights=True,
|
155 |
+
use_cache=True,
|
156 |
+
bos_token_id=50256,
|
157 |
+
eos_token_id=50256,
|
158 |
+
tie_word_embeddings=False,
|
159 |
+
gradient_checkpointing=True,
|
160 |
+
gradient_checkpointing_policy='nothing_saveable',
|
161 |
+
n_real_tokens=50257,
|
162 |
+
fcm_min_ratio=0.0,
|
163 |
+
fcm_max_ratio=0.0,
|
164 |
+
**kwargs
|
165 |
+
):
|
166 |
+
self.vocab_size = vocab_size
|
167 |
+
self.n_positions = n_positions
|
168 |
+
self.n_embd = n_embd
|
169 |
+
self.n_layer = n_layer
|
170 |
+
self.n_head = n_head
|
171 |
+
self.n_inner = n_inner
|
172 |
+
self.rotary_dim = rotary_dim
|
173 |
+
self.activation_function = activation_function
|
174 |
+
self.resid_pdrop = resid_pdrop
|
175 |
+
self.embd_pdrop = embd_pdrop
|
176 |
+
self.attn_pdrop = attn_pdrop
|
177 |
+
self.layer_norm_epsilon = layer_norm_epsilon
|
178 |
+
self.initializer_range = initializer_range
|
179 |
+
self.scale_attn_weights = scale_attn_weights
|
180 |
+
self.use_cache = use_cache
|
181 |
+
self.gradient_checkpointing = gradient_checkpointing
|
182 |
+
self.gradient_checkpointing_policy = gradient_checkpointing_policy
|
183 |
+
self.n_real_tokens = n_real_tokens
|
184 |
+
self.fcm_min_ratio = fcm_min_ratio
|
185 |
+
self.fcm_max_ratio = fcm_max_ratio
|
186 |
+
if self.n_real_tokens is None:
|
187 |
+
self.n_real_tokens = self.vocab_size
|
188 |
+
|
189 |
+
self.bos_token_id = bos_token_id
|
190 |
+
self.eos_token_id = eos_token_id
|
191 |
+
|
192 |
+
super().__init__(
|
193 |
+
bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
|
194 |
+
)
|
195 |
+
|
196 |
+
@classmethod
|
197 |
+
def get_default_config(cls, updates=None):
|
198 |
+
none_arg_types = dict(
|
199 |
+
n_inner=int,
|
200 |
+
rotary_dim=int,
|
201 |
+
)
|
202 |
+
config = function_args_to_config(cls.__init__, none_arg_types=none_arg_types)
|
203 |
+
|
204 |
+
if updates is not None:
|
205 |
+
config.update(ConfigDict(updates).copy_and_resolve_references())
|
206 |
+
|
207 |
+
return config
|
208 |
+
|
209 |
+
@staticmethod
|
210 |
+
def get_jax_mesh(axis_dims):
|
211 |
+
return get_jax_mesh(axis_dims, ('dp', 'fsdp', 'mp'))
|
212 |
+
|
213 |
+
@staticmethod
|
214 |
+
def get_partition_rules():
|
215 |
+
""" Parition rules for GPTJ. Note that these rules are orderd, so that
|
216 |
+
the beginning rules match first. It is important to use
|
217 |
+
PartitionSpec() instead of None here because JAX does not treat
|
218 |
+
None as a pytree leaf.
|
219 |
+
"""
|
220 |
+
return (
|
221 |
+
('transformer/wte/embedding', PartitionSpec('mp', 'fsdp')),
|
222 |
+
('attn/(k_proj|q_proj|v_proj)/kernel', PartitionSpec('fsdp', 'mp')),
|
223 |
+
('attn/out_proj/kernel', PartitionSpec('mp', 'fsdp')),
|
224 |
+
('mlp/fc_in/kernel', PartitionSpec('fsdp', 'mp')),
|
225 |
+
('mlp/fc_in/bias', PartitionSpec('mp')),
|
226 |
+
('mlp/fc_out/kernel', PartitionSpec('mp', 'fsdp')),
|
227 |
+
('mlp/fc_out/bias', PartitionSpec()),
|
228 |
+
('ln_[0-9]+/bias', PartitionSpec()),
|
229 |
+
('[0-9]+/ln_[0-9]+/scale', PartitionSpec()),
|
230 |
+
('ln_f/bias', PartitionSpec()),
|
231 |
+
('ln_f/scale', PartitionSpec()),
|
232 |
+
('lm_head/kernel', PartitionSpec('fsdp', 'mp')),
|
233 |
+
('lm_head/bias', PartitionSpec('mp')),
|
234 |
+
('.*', PartitionSpec()),
|
235 |
+
)
|
236 |
+
|
237 |
+
@staticmethod
|
238 |
+
def get_weight_decay_exclusions():
|
239 |
+
return (
|
240 |
+
'ln_[0-9]+/bias', 'ln_[0-9]+/scale', 'ln_f/bias', 'ln_f/scale',
|
241 |
+
'bias'
|
242 |
+
)
|
243 |
+
|
244 |
+
@staticmethod
|
245 |
+
def rng_keys():
|
246 |
+
return ('params', 'dropout', 'fcm')
|
247 |
+
|
248 |
+
@staticmethod
|
249 |
+
def get_tokenizer_config(updates=None):
|
250 |
+
config = ConfigDict()
|
251 |
+
config.name = 'EleutherAI/gpt-j-6B'
|
252 |
+
config.bos_token = '<|endoftext|>'
|
253 |
+
config.eos_token = '<|endoftext|>'
|
254 |
+
config.pad_token = '<|extratoken_40|>'
|
255 |
+
config.cls_token = '<|extratoken_41|>'
|
256 |
+
config.mask_token = '<|extratoken_42|>'
|
257 |
+
|
258 |
+
if updates is not None:
|
259 |
+
config.update(ConfigDict(updates).copy_and_resolve_references())
|
260 |
+
|
261 |
+
return config
|
262 |
+
|
263 |
+
@classmethod
|
264 |
+
def get_tokenizer(cls, config, padding_side='left', truncation_side='right'):
|
265 |
+
config = cls.get_tokenizer_config(config)
|
266 |
+
return AutoTokenizer.from_pretrained(
|
267 |
+
config.name,
|
268 |
+
bos_token=config.bos_token,
|
269 |
+
eos_token=config.eos_token,
|
270 |
+
pad_token=config.pad_token,
|
271 |
+
cls_token=config.cls_token,
|
272 |
+
mask_token=config.mask_token,
|
273 |
+
padding_side=padding_side,
|
274 |
+
truncation_side=truncation_side,
|
275 |
+
)
|
276 |
+
|
277 |
+
@staticmethod
|
278 |
+
def load_pretrained(name, dtype=jnp.float32):
|
279 |
+
with jax.default_device(jax.devices("cpu")[0]):
|
280 |
+
params = FlaxGPTJForCausalLM.from_pretrained(
|
281 |
+
name, _do_init=False, dtype=dtype
|
282 |
+
)[1]
|
283 |
+
params = freeze({'params': params})
|
284 |
+
return jax.device_get(params)
|
285 |
+
|
286 |
+
@classmethod
|
287 |
+
def load_config(cls, path):
|
288 |
+
if path in GPTJ_STANDARD_CONFIGS:
|
289 |
+
return cls.from_dict(GPTJ_STANDARD_CONFIGS[path])
|
290 |
+
load_type, load_path = path.split('::', 1)
|
291 |
+
if load_type == 'pickle':
|
292 |
+
return cls.from_dict(load_pickle(load_path)['gptj_config'])
|
293 |
+
elif load_type == 'json':
|
294 |
+
with open_file(load_path, 'r') as fin:
|
295 |
+
raw_config = fin.read()
|
296 |
+
return cls.from_dict(json.loads(raw_config))
|
297 |
+
elif load_type == 'huggingface':
|
298 |
+
return cls.from_pretrained(load_path)
|
299 |
+
else:
|
300 |
+
raise ValueError(f'Unsupported load config type: {load_type}')
|
301 |
+
|
302 |
+
|
303 |
+
"""
|
304 |
+
The follow code is taken from
|
305 |
+
transformers/src/transformers/models/gptj/modeling_flax_gptj.py
|
306 |
+
and modified to work with EasyLM.
|
307 |
+
"""
|
308 |
+
|
309 |
+
logger = logging.get_logger(__name__)
|
310 |
+
|
311 |
+
_CHECKPOINT_FOR_DOC = "gptj"
|
312 |
+
_CONFIG_FOR_DOC = "GPTJConfig"
|
313 |
+
|
314 |
+
remat = nn_partitioning.remat
|
315 |
+
|
316 |
+
|
317 |
+
GPTJ_START_DOCSTRING = r"""
|
318 |
+
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
|
319 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
320 |
+
etc.)
|
321 |
+
This model is also a Flax Linen
|
322 |
+
[flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
|
323 |
+
regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
|
324 |
+
Finally, this model supports inherent JAX features such as:
|
325 |
+
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
|
326 |
+
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
|
327 |
+
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
|
328 |
+
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
|
329 |
+
Parameters:
|
330 |
+
config ([`GPTJConfig`]): Model configuration class with all the parameters of the model.
|
331 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
332 |
+
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
|
333 |
+
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
|
334 |
+
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
|
335 |
+
`jax.numpy.bfloat16` (on TPUs).
|
336 |
+
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
337 |
+
specified all the computation will be performed with the given `dtype`.
|
338 |
+
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
339 |
+
parameters.**
|
340 |
+
If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
|
341 |
+
[`~FlaxPreTrainedModel.to_bf16`].
|
342 |
+
"""
|
343 |
+
|
344 |
+
GPTJ_INPUTS_DOCSTRING = r"""
|
345 |
+
Args:
|
346 |
+
input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`):
|
347 |
+
`input_ids_length` = `sequence_length`. Indices of input sequence tokens in the vocabulary.
|
348 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
349 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
350 |
+
[What are input IDs?](../glossary#input-ids)
|
351 |
+
attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
352 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
353 |
+
- 1 for tokens that are **not masked**,
|
354 |
+
- 0 for tokens that are **masked**.
|
355 |
+
[What are attention masks?](../glossary#attention-mask)
|
356 |
+
position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
357 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
358 |
+
config.max_position_embeddings - 1]`.
|
359 |
+
past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
|
360 |
+
Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
|
361 |
+
auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
|
362 |
+
output_attentions (`bool`, *optional*):
|
363 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
364 |
+
tensors for more detail.
|
365 |
+
output_hidden_states (`bool`, *optional*):
|
366 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
367 |
+
more detail.
|
368 |
+
return_dict (`bool`, *optional*):
|
369 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
370 |
+
"""
|
371 |
+
|
372 |
+
|
373 |
+
|
374 |
+
def create_sinusoidal_positions(num_pos, dim):
|
375 |
+
inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))
|
376 |
+
sinusoid_inp = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32")
|
377 |
+
sin, cos = np.sin(sinusoid_inp), np.cos(sinusoid_inp)
|
378 |
+
|
379 |
+
sentinel = dim // 2 + dim % 2
|
380 |
+
out = np.zeros((num_pos, dim))
|
381 |
+
out[:, 0:sentinel] = sin
|
382 |
+
out[:, sentinel:] = cos
|
383 |
+
|
384 |
+
return jnp.array(out)
|
385 |
+
|
386 |
+
|
387 |
+
def rotate_every_two(tensor):
|
388 |
+
rotate_half_tensor = jnp.stack((-tensor[:, :, :, 1::2], tensor[:, :, :, ::2]), axis=-1)
|
389 |
+
rotate_half_tensor = rotate_half_tensor.reshape(rotate_half_tensor.shape[:-2] + (-1,))
|
390 |
+
return rotate_half_tensor
|
391 |
+
|
392 |
+
|
393 |
+
def apply_rotary_pos_emb(tensor, sincos):
|
394 |
+
sin_pos, cos_pos = sincos
|
395 |
+
sin_pos = sin_pos[:, :, None, :].repeat(2, 3)
|
396 |
+
cos_pos = cos_pos[:, :, None, :].repeat(2, 3)
|
397 |
+
return (tensor * cos_pos) + (rotate_every_two(tensor) * sin_pos)
|
398 |
+
|
399 |
+
|
400 |
+
class FlaxGPTJAttention(nn.Module):
|
401 |
+
config: GPTJConfig
|
402 |
+
dtype: jnp.dtype = jnp.float32
|
403 |
+
causal: bool = True
|
404 |
+
is_cross_attention: bool = False
|
405 |
+
|
406 |
+
def setup(self):
|
407 |
+
config = self.config
|
408 |
+
self.embed_dim = config.hidden_size
|
409 |
+
self.num_heads = config.num_attention_heads
|
410 |
+
self.head_dim = self.embed_dim // self.num_heads
|
411 |
+
|
412 |
+
self.rotary_dim = config.rotary_dim
|
413 |
+
|
414 |
+
dense = partial(
|
415 |
+
nn.Dense,
|
416 |
+
self.embed_dim,
|
417 |
+
use_bias=False,
|
418 |
+
dtype=self.dtype,
|
419 |
+
kernel_init=jax.nn.initializers.variance_scaling(
|
420 |
+
scale=1.0, mode='fan_in',
|
421 |
+
distribution='normal',
|
422 |
+
)
|
423 |
+
)
|
424 |
+
|
425 |
+
self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
|
426 |
+
self.out_proj = dense()
|
427 |
+
|
428 |
+
self.resid_dropout = nn.Dropout(rate=config.resid_pdrop)
|
429 |
+
|
430 |
+
self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
|
431 |
+
|
432 |
+
if self.rotary_dim is not None and self.rotary_dim > 0:
|
433 |
+
pos_embd_dim = self.rotary_dim
|
434 |
+
else:
|
435 |
+
pos_embd_dim = self.embed_dim // self.num_heads
|
436 |
+
self.embed_positions = create_sinusoidal_positions(config.max_position_embeddings, pos_embd_dim)
|
437 |
+
|
438 |
+
def _split_heads(self, hidden_states):
|
439 |
+
return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
|
440 |
+
|
441 |
+
def _merge_heads(self, hidden_states):
|
442 |
+
return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
|
443 |
+
|
444 |
+
@nn.compact
|
445 |
+
def _concatenate_to_cache(self, key, value, query, attention_mask):
|
446 |
+
"""
|
447 |
+
This function takes projected key, value states from a single input token and concatenates the states to cached
|
448 |
+
states from previous steps. This function is slighly adapted from the official Flax repository:
|
449 |
+
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
|
450 |
+
"""
|
451 |
+
# detect if we're initializing by absence of existing cache data.
|
452 |
+
is_initialized = self.has_variable("cache", "cached_key")
|
453 |
+
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
|
454 |
+
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
|
455 |
+
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
|
456 |
+
|
457 |
+
if is_initialized:
|
458 |
+
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
|
459 |
+
# update key, value caches with our new 1d spatial slices
|
460 |
+
cur_index = cache_index.value
|
461 |
+
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
|
462 |
+
key = lax.dynamic_update_slice(cached_key.value, key, indices)
|
463 |
+
value = lax.dynamic_update_slice(cached_value.value, value, indices)
|
464 |
+
cached_key.value = key
|
465 |
+
cached_value.value = value
|
466 |
+
num_updated_cache_vectors = query.shape[1]
|
467 |
+
cache_index.value = cache_index.value + num_updated_cache_vectors
|
468 |
+
# causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
|
469 |
+
pad_mask = jnp.broadcast_to(
|
470 |
+
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
|
471 |
+
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
|
472 |
+
)
|
473 |
+
attention_mask = combine_masks(pad_mask, attention_mask)
|
474 |
+
return key, value, attention_mask
|
475 |
+
|
476 |
+
def __call__(
|
477 |
+
self,
|
478 |
+
hidden_states,
|
479 |
+
attention_mask,
|
480 |
+
position_ids,
|
481 |
+
deterministic: bool = True,
|
482 |
+
init_cache: bool = False,
|
483 |
+
output_attentions: bool = False,
|
484 |
+
fcm_mask=None,
|
485 |
+
):
|
486 |
+
|
487 |
+
query = self.q_proj(hidden_states)
|
488 |
+
key = self.k_proj(hidden_states)
|
489 |
+
value = self.v_proj(hidden_states)
|
490 |
+
|
491 |
+
query = self._split_heads(query)
|
492 |
+
key = self._split_heads(key)
|
493 |
+
value = self._split_heads(value)
|
494 |
+
|
495 |
+
sincos = jnp.take(self.embed_positions, position_ids, axis=0)
|
496 |
+
sincos = jnp.split(sincos, 2, axis=-1)
|
497 |
+
# Rotary position embeddings induce some weird issues in multi-host environments, so we remove activation-sharding for keys/query vectors to fix this.
|
498 |
+
# key = with_sharding_constraint(key, PartitionSpec("dp", None, None, None))
|
499 |
+
# query = with_sharding_constraint(query, PartitionSpec("dp", None, None, None))
|
500 |
+
if self.rotary_dim is not None and self.rotary_dim > 0:
|
501 |
+
k_rot = key[:, :, :, : self.rotary_dim]
|
502 |
+
k_pass = key[:, :, :, self.rotary_dim :]
|
503 |
+
|
504 |
+
q_rot = query[:, :, :, : self.rotary_dim]
|
505 |
+
q_pass = query[:, :, :, self.rotary_dim :]
|
506 |
+
|
507 |
+
k_rot = apply_rotary_pos_emb(k_rot, sincos)
|
508 |
+
q_rot = apply_rotary_pos_emb(q_rot, sincos)
|
509 |
+
|
510 |
+
key = jnp.concatenate([k_rot, k_pass], axis=-1)
|
511 |
+
query = jnp.concatenate([q_rot, q_pass], axis=-1)
|
512 |
+
else:
|
513 |
+
key = apply_rotary_pos_emb(key, sincos)
|
514 |
+
query = apply_rotary_pos_emb(query, sincos)
|
515 |
+
|
516 |
+
query_length, key_length = query.shape[1], key.shape[1]
|
517 |
+
|
518 |
+
if self.has_variable("cache", "cached_key"):
|
519 |
+
mask_shift = self.variables["cache"]["cache_index"]
|
520 |
+
max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
|
521 |
+
causal_mask = lax.dynamic_slice(
|
522 |
+
self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
|
523 |
+
)
|
524 |
+
else:
|
525 |
+
causal_mask = self.causal_mask[:, :, :query_length, :key_length]
|
526 |
+
|
527 |
+
batch_size = hidden_states.shape[0]
|
528 |
+
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
|
529 |
+
|
530 |
+
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
|
531 |
+
if self.causal:
|
532 |
+
attention_mask = combine_masks(attention_mask, causal_mask, fcm_mask)
|
533 |
+
else:
|
534 |
+
attention_mask = attention_mask
|
535 |
+
|
536 |
+
dropout_rng = None
|
537 |
+
if not deterministic and self.config.attn_pdrop > 0.0:
|
538 |
+
dropout_rng = self.make_rng("dropout")
|
539 |
+
|
540 |
+
# During fast autoregressive decoding, we feed one position at a time,
|
541 |
+
# and cache the keys and values step by step.
|
542 |
+
if self.has_variable("cache", "cached_key") or init_cache:
|
543 |
+
key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)
|
544 |
+
|
545 |
+
# transform boolean mask into float mask
|
546 |
+
attention_bias = lax.select(
|
547 |
+
attention_mask > 0,
|
548 |
+
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
|
549 |
+
jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
|
550 |
+
)
|
551 |
+
|
552 |
+
# usual dot product attention
|
553 |
+
attn_weights = dot_product_attention_weights(
|
554 |
+
query,
|
555 |
+
key,
|
556 |
+
bias=attention_bias,
|
557 |
+
dropout_rng=dropout_rng,
|
558 |
+
dropout_rate=self.config.attn_pdrop,
|
559 |
+
deterministic=deterministic,
|
560 |
+
dtype=jnp.promote_types(self.dtype, jnp.float32),
|
561 |
+
precision=None,
|
562 |
+
)
|
563 |
+
|
564 |
+
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
|
565 |
+
attn_output = self._merge_heads(attn_output)
|
566 |
+
attn_output = self.out_proj(attn_output)
|
567 |
+
attn_output = self.resid_dropout(attn_output, deterministic=deterministic)
|
568 |
+
|
569 |
+
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
|
570 |
+
return outputs
|
571 |
+
|
572 |
+
|
573 |
+
class FlaxGPTJMLP(nn.Module):
|
574 |
+
config: GPTJConfig
|
575 |
+
intermediate_size: int
|
576 |
+
dtype: jnp.dtype = jnp.float32
|
577 |
+
|
578 |
+
def setup(self):
|
579 |
+
embed_dim = self.config.hidden_size
|
580 |
+
kernel_init=jax.nn.initializers.variance_scaling(
|
581 |
+
scale=1.0, mode='fan_in',
|
582 |
+
distribution='normal',
|
583 |
+
)
|
584 |
+
|
585 |
+
self.fc_in = nn.Dense(self.intermediate_size, dtype=self.dtype, kernel_init=kernel_init)
|
586 |
+
self.fc_out = nn.Dense(embed_dim, dtype=self.dtype, kernel_init=kernel_init)
|
587 |
+
|
588 |
+
self.act = ACT2FN[self.config.activation_function]
|
589 |
+
self.dropout = nn.Dropout(rate=self.config.resid_pdrop)
|
590 |
+
|
591 |
+
def __call__(self, hidden_states, deterministic: bool = True):
|
592 |
+
hidden_states = self.fc_in(hidden_states)
|
593 |
+
hidden_states = self.act(hidden_states)
|
594 |
+
hidden_states = self.fc_out(hidden_states)
|
595 |
+
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
596 |
+
return hidden_states
|
597 |
+
|
598 |
+
|
599 |
+
class FlaxGPTJBlock(nn.Module):
|
600 |
+
config: GPTJConfig
|
601 |
+
dtype: jnp.dtype = jnp.float32
|
602 |
+
|
603 |
+
def setup(self):
|
604 |
+
hidden_size = self.config.hidden_size
|
605 |
+
inner_dim = self.config.n_inner if self.config.n_inner is not None else 4 * hidden_size
|
606 |
+
|
607 |
+
self.ln_1 = nn.LayerNorm(
|
608 |
+
epsilon=self.config.layer_norm_epsilon,
|
609 |
+
dtype=jnp.promote_types(self.dtype, jnp.float32)
|
610 |
+
)
|
611 |
+
self.attn = FlaxGPTJAttention(self.config, dtype=self.dtype)
|
612 |
+
|
613 |
+
self.mlp = FlaxGPTJMLP(self.config, inner_dim, dtype=self.dtype)
|
614 |
+
|
615 |
+
def __call__(
|
616 |
+
self,
|
617 |
+
hidden_states,
|
618 |
+
attention_mask=None,
|
619 |
+
position_ids=None,
|
620 |
+
deterministic: bool = True,
|
621 |
+
init_cache: bool = False,
|
622 |
+
output_attentions: bool = False,
|
623 |
+
fcm_mask=None,
|
624 |
+
):
|
625 |
+
residual = hidden_states
|
626 |
+
hidden_states = self.ln_1(hidden_states)
|
627 |
+
attn_outputs = self.attn(
|
628 |
+
hidden_states,
|
629 |
+
attention_mask=attention_mask,
|
630 |
+
position_ids=position_ids,
|
631 |
+
deterministic=deterministic,
|
632 |
+
init_cache=init_cache,
|
633 |
+
output_attentions=output_attentions,
|
634 |
+
fcm_mask=fcm_mask,
|
635 |
+
)
|
636 |
+
attn_output = attn_outputs[0]
|
637 |
+
|
638 |
+
feed_forward_hidden_states = self.mlp(hidden_states, deterministic=deterministic)
|
639 |
+
# residual connection
|
640 |
+
hidden_states = attn_output + feed_forward_hidden_states + residual
|
641 |
+
|
642 |
+
return (hidden_states,) + attn_outputs[1:]
|
643 |
+
|
644 |
+
|
645 |
+
class FlaxGPTJPreTrainedModel(FlaxPreTrainedModel):
|
646 |
+
"""
|
647 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
648 |
+
models.
|
649 |
+
"""
|
650 |
+
|
651 |
+
config_class = GPTJConfig
|
652 |
+
base_model_prefix = "transformer"
|
653 |
+
module_class: nn.Module = None
|
654 |
+
|
655 |
+
def __init__(
|
656 |
+
self,
|
657 |
+
config: GPTJConfig,
|
658 |
+
input_shape: Tuple = (1, 1),
|
659 |
+
seed: int = 0,
|
660 |
+
dtype: jnp.dtype = jnp.float32,
|
661 |
+
_do_init: bool = True,
|
662 |
+
**kwargs,
|
663 |
+
):
|
664 |
+
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
665 |
+
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
666 |
+
|
667 |
+
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
668 |
+
# init input tensors
|
669 |
+
input_ids = jnp.zeros(input_shape, dtype="i4")
|
670 |
+
attention_mask = jnp.ones_like(input_ids)
|
671 |
+
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
|
672 |
+
params_rng, dropout_rng = jax.random.split(rng)
|
673 |
+
rngs = {"params": params_rng, "dropout": dropout_rng}
|
674 |
+
|
675 |
+
if self.config.add_cross_attention:
|
676 |
+
encoder_hidden_states = jnp.zeros(input_shape + (self.config.n_embd,))
|
677 |
+
encoder_attention_mask = attention_mask
|
678 |
+
module_init_outputs = self.module.init(
|
679 |
+
rngs,
|
680 |
+
input_ids,
|
681 |
+
attention_mask,
|
682 |
+
position_ids,
|
683 |
+
encoder_hidden_states,
|
684 |
+
encoder_attention_mask,
|
685 |
+
return_dict=False,
|
686 |
+
)
|
687 |
+
else:
|
688 |
+
module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)
|
689 |
+
|
690 |
+
random_params = module_init_outputs["params"]
|
691 |
+
|
692 |
+
if params is not None:
|
693 |
+
random_params = flatten_dict(unfreeze(random_params))
|
694 |
+
params = flatten_dict(unfreeze(params))
|
695 |
+
for missing_key in self._missing_keys:
|
696 |
+
params[missing_key] = random_params[missing_key]
|
697 |
+
self._missing_keys = set()
|
698 |
+
return freeze(unflatten_dict(params))
|
699 |
+
else:
|
700 |
+
return random_params
|
701 |
+
|
702 |
+
def init_cache(self, batch_size, max_length):
|
703 |
+
r"""
|
704 |
+
Args:
|
705 |
+
batch_size (`int`):
|
706 |
+
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
|
707 |
+
max_length (`int`):
|
708 |
+
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
|
709 |
+
cache.
|
710 |
+
"""
|
711 |
+
# init input variables to retrieve cache
|
712 |
+
input_ids = jnp.ones((batch_size, max_length))
|
713 |
+
attention_mask = jnp.ones_like(input_ids)
|
714 |
+
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
|
715 |
+
|
716 |
+
init_variables = self.module.init(
|
717 |
+
jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
|
718 |
+
)
|
719 |
+
return init_variables["cache"]
|
720 |
+
|
721 |
+
def _get_logits_processor(self,*args, **kwargs) -> FlaxLogitsProcessorList:
|
722 |
+
processors = super()._get_logits_processor(*args, **kwargs)
|
723 |
+
def squash_extra_tokens(input_ids, scores, cur_len):
|
724 |
+
return scores.at[:, self.config.n_real_tokens:].set(-float('inf'))
|
725 |
+
|
726 |
+
processors.append(squash_extra_tokens)
|
727 |
+
return processors
|
728 |
+
|
729 |
+
@add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING)
|
730 |
+
def __call__(
|
731 |
+
self,
|
732 |
+
input_ids,
|
733 |
+
attention_mask=None,
|
734 |
+
position_ids=None,
|
735 |
+
params: dict = None,
|
736 |
+
past_key_values: dict = None,
|
737 |
+
dropout_rng: jax.random.PRNGKey = None,
|
738 |
+
train: bool = False,
|
739 |
+
output_attentions: Optional[bool] = None,
|
740 |
+
output_hidden_states: Optional[bool] = None,
|
741 |
+
return_dict: Optional[bool] = None,
|
742 |
+
):
|
743 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
744 |
+
output_hidden_states = (
|
745 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
746 |
+
)
|
747 |
+
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
748 |
+
|
749 |
+
batch_size, sequence_length = input_ids.shape
|
750 |
+
|
751 |
+
if position_ids is None:
|
752 |
+
if past_key_values is not None:
|
753 |
+
raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")
|
754 |
+
|
755 |
+
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
|
756 |
+
|
757 |
+
if attention_mask is None:
|
758 |
+
attention_mask = jnp.ones((batch_size, sequence_length))
|
759 |
+
|
760 |
+
# Handle any PRNG if needed
|
761 |
+
rngs = {}
|
762 |
+
if dropout_rng is not None:
|
763 |
+
rngs["dropout"] = dropout_rng
|
764 |
+
|
765 |
+
inputs = {"params": params or self.params}
|
766 |
+
|
767 |
+
# if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be changed by FlaxGPTJAttention module
|
768 |
+
if past_key_values:
|
769 |
+
inputs["cache"] = past_key_values
|
770 |
+
mutable = ["cache"]
|
771 |
+
else:
|
772 |
+
mutable = False
|
773 |
+
|
774 |
+
outputs = self.module.apply(
|
775 |
+
inputs,
|
776 |
+
jnp.array(input_ids, dtype="i4"),
|
777 |
+
jnp.array(attention_mask, dtype="i4"),
|
778 |
+
jnp.array(position_ids, dtype="i4"),
|
779 |
+
not train,
|
780 |
+
False,
|
781 |
+
output_attentions,
|
782 |
+
output_hidden_states,
|
783 |
+
return_dict,
|
784 |
+
rngs=rngs,
|
785 |
+
mutable=mutable,
|
786 |
+
)
|
787 |
+
|
788 |
+
# add updated cache to model output
|
789 |
+
if past_key_values is not None and return_dict:
|
790 |
+
outputs, past_key_values = outputs
|
791 |
+
outputs["past_key_values"] = unfreeze(past_key_values["cache"])
|
792 |
+
return outputs
|
793 |
+
elif past_key_values is not None and not return_dict:
|
794 |
+
outputs, past_key_values = outputs
|
795 |
+
outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
|
796 |
+
|
797 |
+
return outputs
|
798 |
+
|
799 |
+
|
800 |
+
class FlaxGPTJBlockCollection(nn.Module):
|
801 |
+
config: GPTJConfig
|
802 |
+
dtype: jnp.dtype = jnp.float32
|
803 |
+
|
804 |
+
def setup(self):
|
805 |
+
block = FlaxGPTJBlock
|
806 |
+
if self.config.gradient_checkpointing:
|
807 |
+
FlaxGPT2CheckpointBlock = remat(
|
808 |
+
block, static_argnums=(3, 4, 5),
|
809 |
+
policy=get_gradient_checkpoint_policy(
|
810 |
+
self.config.gradient_checkpointing_policy
|
811 |
+
)
|
812 |
+
)
|
813 |
+
block = FlaxGPT2CheckpointBlock
|
814 |
+
self.blocks = [
|
815 |
+
block(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
|
816 |
+
]
|
817 |
+
|
818 |
+
def __call__(
|
819 |
+
self,
|
820 |
+
hidden_states,
|
821 |
+
attention_mask=None,
|
822 |
+
position_ids=None,
|
823 |
+
deterministic: bool = True,
|
824 |
+
init_cache: bool = False,
|
825 |
+
output_attentions: bool = False,
|
826 |
+
output_hidden_states: bool = False,
|
827 |
+
return_dict: bool = True,
|
828 |
+
):
|
829 |
+
all_attentions = () if output_attentions else None
|
830 |
+
all_hidden_states = () if output_hidden_states else None
|
831 |
+
|
832 |
+
if not deterministic and self.config.fcm_max_ratio > 0:
|
833 |
+
# Apply forgetful causal mask
|
834 |
+
batch_size, seq_length = hidden_states.shape[0], hidden_states.shape[1]
|
835 |
+
fcm_ratio = jax.random.uniform(
|
836 |
+
self.make_rng('fcm'), shape=(batch_size, 1, 1, 1),
|
837 |
+
minval=self.config.fcm_min_ratio,
|
838 |
+
maxval=self.config.fcm_max_ratio
|
839 |
+
)
|
840 |
+
fcm_mask = jax.random.uniform(
|
841 |
+
self.make_rng('fcm'),
|
842 |
+
shape=(batch_size, 1, seq_length, seq_length)
|
843 |
+
) > fcm_ratio
|
844 |
+
fcm_mask = fcm_mask.at[:, :, :, 0].set(True)
|
845 |
+
fcm_mask = fcm_mask.astype('bool')
|
846 |
+
else:
|
847 |
+
fcm_mask = None
|
848 |
+
|
849 |
+
for block in self.blocks:
|
850 |
+
if output_hidden_states:
|
851 |
+
all_hidden_states += (hidden_states,)
|
852 |
+
|
853 |
+
layer_outputs = block(
|
854 |
+
hidden_states,
|
855 |
+
attention_mask,
|
856 |
+
position_ids,
|
857 |
+
deterministic,
|
858 |
+
init_cache,
|
859 |
+
output_attentions,
|
860 |
+
fcm_mask,
|
861 |
+
)
|
862 |
+
hidden_states = layer_outputs[0]
|
863 |
+
|
864 |
+
if output_attentions:
|
865 |
+
all_attentions += (layer_outputs[1],)
|
866 |
+
|
867 |
+
# this contains possible `None` values - `FlaxGPTJModule` will filter them out
|
868 |
+
outputs = (hidden_states, all_hidden_states, all_attentions)
|
869 |
+
|
870 |
+
return outputs
|
871 |
+
|
872 |
+
|
873 |
+
class FlaxGPTJModule(nn.Module):
|
874 |
+
config: GPTJConfig
|
875 |
+
dtype: jnp.dtype = jnp.float32
|
876 |
+
|
877 |
+
def setup(self):
|
878 |
+
self.embed_dim = self.config.hidden_size
|
879 |
+
|
880 |
+
self.wte = nn.Embed(
|
881 |
+
self.config.vocab_size,
|
882 |
+
self.config.hidden_size,
|
883 |
+
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
884 |
+
)
|
885 |
+
self.dropout = nn.Dropout(rate=self.config.embd_pdrop)
|
886 |
+
self.h = FlaxGPTJBlockCollection(self.config, dtype=self.dtype)
|
887 |
+
self.ln_f = nn.LayerNorm(
|
888 |
+
epsilon=self.config.layer_norm_epsilon,
|
889 |
+
dtype=jnp.promote_types(self.dtype, jnp.float32)
|
890 |
+
)
|
891 |
+
|
892 |
+
def __call__(
|
893 |
+
self,
|
894 |
+
input_ids,
|
895 |
+
attention_mask,
|
896 |
+
position_ids,
|
897 |
+
deterministic=True,
|
898 |
+
init_cache: bool = False,
|
899 |
+
output_attentions: bool = False,
|
900 |
+
output_hidden_states: bool = False,
|
901 |
+
return_dict: bool = True,
|
902 |
+
):
|
903 |
+
input_embeds = self.wte(input_ids.astype("i4"))
|
904 |
+
|
905 |
+
hidden_states = self.dropout(input_embeds, deterministic=deterministic)
|
906 |
+
|
907 |
+
outputs = self.h(
|
908 |
+
hidden_states,
|
909 |
+
attention_mask,
|
910 |
+
position_ids=position_ids,
|
911 |
+
deterministic=deterministic,
|
912 |
+
init_cache=init_cache,
|
913 |
+
output_attentions=output_attentions,
|
914 |
+
output_hidden_states=output_hidden_states,
|
915 |
+
return_dict=return_dict,
|
916 |
+
)
|
917 |
+
|
918 |
+
hidden_states = outputs[0]
|
919 |
+
hidden_states = self.ln_f(hidden_states)
|
920 |
+
|
921 |
+
if output_hidden_states:
|
922 |
+
all_hidden_states = outputs[1] + (hidden_states,)
|
923 |
+
outputs = (hidden_states, all_hidden_states) + outputs[2:]
|
924 |
+
else:
|
925 |
+
outputs = (hidden_states,) + outputs[1:]
|
926 |
+
|
927 |
+
if not return_dict:
|
928 |
+
return tuple(v for v in outputs if v is not None)
|
929 |
+
|
930 |
+
return FlaxBaseModelOutput(
|
931 |
+
last_hidden_state=hidden_states,
|
932 |
+
hidden_states=outputs[1],
|
933 |
+
attentions=outputs[-1],
|
934 |
+
)
|
935 |
+
|
936 |
+
|
937 |
+
@add_start_docstrings(
|
938 |
+
"The bare GPTJ Model transformer outputting raw hidden-states without any specific head on top.",
|
939 |
+
GPTJ_START_DOCSTRING,
|
940 |
+
)
|
941 |
+
class FlaxGPTJModel(FlaxGPTJPreTrainedModel):
|
942 |
+
module_class = FlaxGPTJModule
|
943 |
+
|
944 |
+
|
945 |
+
append_call_sample_docstring(
|
946 |
+
FlaxGPTJModel,
|
947 |
+
_CHECKPOINT_FOR_DOC,
|
948 |
+
FlaxCausalLMOutput,
|
949 |
+
_CONFIG_FOR_DOC,
|
950 |
+
)
|
951 |
+
|
952 |
+
|
953 |
+
class FlaxGPTJForCausalLMModule(nn.Module):
|
954 |
+
config: GPTJConfig
|
955 |
+
dtype: jnp.dtype = jnp.float32
|
956 |
+
|
957 |
+
def setup(self):
|
958 |
+
self.transformer = FlaxGPTJModule(self.config, dtype=self.dtype)
|
959 |
+
self.lm_head = nn.Dense(
|
960 |
+
self.config.vocab_size,
|
961 |
+
dtype=self.dtype,
|
962 |
+
kernel_init=jax.nn.initializers.variance_scaling(
|
963 |
+
scale=1.0, mode='fan_in',
|
964 |
+
distribution='normal',
|
965 |
+
)
|
966 |
+
)
|
967 |
+
|
968 |
+
def __call__(
|
969 |
+
self,
|
970 |
+
input_ids,
|
971 |
+
attention_mask=None,
|
972 |
+
position_ids=None,
|
973 |
+
deterministic: bool = True,
|
974 |
+
init_cache: bool = False,
|
975 |
+
output_attentions: bool = False,
|
976 |
+
output_hidden_states: bool = False,
|
977 |
+
return_dict: bool = True,
|
978 |
+
):
|
979 |
+
batch_size, seq_length = input_ids.shape
|
980 |
+
if attention_mask is None:
|
981 |
+
attention_mask = jnp.ones_like(input_ids)
|
982 |
+
if position_ids is None:
|
983 |
+
position_ids = jnp.broadcast_to(
|
984 |
+
jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0),
|
985 |
+
(batch_size, seq_length)
|
986 |
+
)
|
987 |
+
|
988 |
+
outputs = self.transformer(
|
989 |
+
input_ids,
|
990 |
+
attention_mask,
|
991 |
+
position_ids,
|
992 |
+
deterministic=deterministic,
|
993 |
+
init_cache=init_cache,
|
994 |
+
output_attentions=output_attentions,
|
995 |
+
output_hidden_states=output_hidden_states,
|
996 |
+
return_dict=return_dict,
|
997 |
+
)
|
998 |
+
|
999 |
+
hidden_states = outputs[0]
|
1000 |
+
|
1001 |
+
if self.config.tie_word_embeddings:
|
1002 |
+
shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T
|
1003 |
+
lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
|
1004 |
+
else:
|
1005 |
+
lm_logits = self.lm_head(hidden_states)
|
1006 |
+
|
1007 |
+
if not return_dict:
|
1008 |
+
return (lm_logits,) + outputs[1:]
|
1009 |
+
|
1010 |
+
return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
|
1011 |
+
|
1012 |
+
|
1013 |
+
@add_start_docstrings(
|
1014 |
+
"""
|
1015 |
+
The GPTJ Model transformer with a language modeling head on top.
|
1016 |
+
""",
|
1017 |
+
GPTJ_START_DOCSTRING,
|
1018 |
+
)
|
1019 |
+
class FlaxGPTJForCausalLM(FlaxGPTJPreTrainedModel):
|
1020 |
+
module_class = FlaxGPTJForCausalLMModule
|
1021 |
+
|
1022 |
+
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
|
1023 |
+
# initializing the cache
|
1024 |
+
batch_size, seq_length = input_ids.shape
|
1025 |
+
|
1026 |
+
past_key_values = self.init_cache(batch_size, max_length)
|
1027 |
+
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
|
1028 |
+
# But since GPTJ uses a causal mask, those positions are masked anyways.
|
1029 |
+
# Thus we can create a single static attention_mask here, which is more efficient for compilation
|
1030 |
+
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
|
1031 |
+
if attention_mask is not None:
|
1032 |
+
position_ids = attention_mask.cumsum(axis=-1) - 1
|
1033 |
+
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
|
1034 |
+
else:
|
1035 |
+
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
|
1036 |
+
|
1037 |
+
return {
|
1038 |
+
"past_key_values": past_key_values,
|
1039 |
+
"attention_mask": extended_attention_mask,
|
1040 |
+
"position_ids": position_ids,
|
1041 |
+
}
|
1042 |
+
|
1043 |
+
def update_inputs_for_generation(self, model_outputs, model_kwargs):
|
1044 |
+
model_kwargs["past_key_values"] = model_outputs.past_key_values
|
1045 |
+
model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
|
1046 |
+
return model_kwargs
|
1047 |
+
|
1048 |
+
|
1049 |
+
append_call_sample_docstring(
|
1050 |
+
FlaxGPTJForCausalLM,
|
1051 |
+
_CHECKPOINT_FOR_DOC,
|
1052 |
+
FlaxCausalLMOutput,
|
1053 |
+
_CONFIG_FOR_DOC,
|
1054 |
+
)
|
EasyLM/models/gptj/gptj_serve.py
ADDED
@@ -0,0 +1,396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pprint
|
2 |
+
from functools import partial
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import mlxu
|
6 |
+
|
7 |
+
import jax
|
8 |
+
import jax.numpy as jnp
|
9 |
+
from jax.experimental.pjit import pjit
|
10 |
+
from jax.sharding import PartitionSpec as PS
|
11 |
+
import flax
|
12 |
+
from flax import linen as nn
|
13 |
+
from flax.jax_utils import prefetch_to_device
|
14 |
+
from flax.training.train_state import TrainState
|
15 |
+
import optax
|
16 |
+
from transformers import GenerationConfig, FlaxLogitsProcessorList
|
17 |
+
|
18 |
+
from EasyLM.checkpoint import StreamingCheckpointer
|
19 |
+
from EasyLM.serving import LMServer
|
20 |
+
from EasyLM.jax_utils import (
|
21 |
+
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, tree_apply,
|
22 |
+
set_random_seed, get_float_dtype_by_name, make_shard_and_gather_fns,
|
23 |
+
with_sharding_constraint, FlaxTemperatureLogitsWarper
|
24 |
+
)
|
25 |
+
from EasyLM.models.gptj.gptj_model import (
|
26 |
+
GPTJConfig, FlaxGPTJForCausalLMModule, FlaxGPTJForCausalLM
|
27 |
+
)
|
28 |
+
|
29 |
+
|
30 |
+
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
31 |
+
seed=42,
|
32 |
+
initialize_jax_distributed=False,
|
33 |
+
mesh_dim='1,-1,1',
|
34 |
+
dtype='bf16',
|
35 |
+
input_length=1024,
|
36 |
+
seq_length=2048,
|
37 |
+
top_k=50,
|
38 |
+
top_p=1.0,
|
39 |
+
do_sample=True,
|
40 |
+
num_beams=1,
|
41 |
+
add_bos_token=False,
|
42 |
+
load_gptj_config='',
|
43 |
+
load_checkpoint='',
|
44 |
+
tokenizer=GPTJConfig.get_tokenizer_config(),
|
45 |
+
lm_server=LMServer.get_default_config(),
|
46 |
+
jax_distributed=JaxDistributedConfig.get_default_config(),
|
47 |
+
)
|
48 |
+
|
49 |
+
|
50 |
+
def main(argv):
|
51 |
+
JaxDistributedConfig.initialize(FLAGS.jax_distributed)
|
52 |
+
set_random_seed(FLAGS.seed)
|
53 |
+
|
54 |
+
prefix_tokenizer = GPTJConfig.get_tokenizer(
|
55 |
+
FLAGS.tokenizer, truncation_side='left', padding_side='left'
|
56 |
+
)
|
57 |
+
tokenizer = GPTJConfig.get_tokenizer(
|
58 |
+
FLAGS.tokenizer, truncation_side='right', padding_side='right'
|
59 |
+
)
|
60 |
+
|
61 |
+
with jax.default_device(jax.devices("cpu")[0]):
|
62 |
+
gptj_config = GPTJConfig.load_config(FLAGS.load_gptj_config)
|
63 |
+
load_type, load_path = FLAGS.load_checkpoint.split('::', 1)
|
64 |
+
if load_type == 'huggingface':
|
65 |
+
params = gptj_config.load_pretrained(load_path)
|
66 |
+
else:
|
67 |
+
_, params = StreamingCheckpointer.load_trainstate_checkpoint(
|
68 |
+
FLAGS.load_checkpoint, disallow_trainstate=True
|
69 |
+
)
|
70 |
+
|
71 |
+
hf_model = FlaxGPTJForCausalLM(
|
72 |
+
gptj_config,
|
73 |
+
input_shape=(1, FLAGS.seq_length),
|
74 |
+
seed=FLAGS.seed,
|
75 |
+
_do_init=False
|
76 |
+
)
|
77 |
+
|
78 |
+
model_ps = match_partition_rules(
|
79 |
+
GPTJConfig.get_partition_rules(), params
|
80 |
+
)
|
81 |
+
shard_fns, _ = make_shard_and_gather_fns(
|
82 |
+
model_ps, get_float_dtype_by_name(FLAGS.dtype)
|
83 |
+
)
|
84 |
+
|
85 |
+
@partial(
|
86 |
+
pjit,
|
87 |
+
in_shardings=(model_ps, PS(), PS()),
|
88 |
+
out_shardings=(PS(), PS(), PS())
|
89 |
+
)
|
90 |
+
def forward_loglikelihood(params, rng, batch):
|
91 |
+
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
|
92 |
+
rng_generator = JaxRNG(rng)
|
93 |
+
input_tokens = batch['input_tokens']
|
94 |
+
output_tokens = batch['output_tokens']
|
95 |
+
input_mask = batch['input_mask']
|
96 |
+
output_mask = batch['output_mask']
|
97 |
+
|
98 |
+
logits = hf_model.module.apply(
|
99 |
+
params, input_tokens, attention_mask=input_mask,
|
100 |
+
deterministic=True, rngs=rng_generator(gptj_config.rng_keys()),
|
101 |
+
).logits
|
102 |
+
if gptj_config.n_real_tokens is not None:
|
103 |
+
logits = logits.at[:, :, gptj_config.n_real_tokens:].set(-1e8)
|
104 |
+
loglikelihood = -optax.softmax_cross_entropy_with_integer_labels(
|
105 |
+
logits, output_tokens
|
106 |
+
)
|
107 |
+
loglikelihood = jnp.sum(loglikelihood * output_mask, axis=-1)
|
108 |
+
match_count = jnp.sum(
|
109 |
+
(jnp.argmax(logits, axis=-1) == output_tokens) * output_mask,
|
110 |
+
axis=-1
|
111 |
+
)
|
112 |
+
total = jnp.sum(output_mask, axis=-1)
|
113 |
+
is_greedy = match_count == total
|
114 |
+
return loglikelihood, is_greedy, rng_generator()
|
115 |
+
|
116 |
+
|
117 |
+
@partial(
|
118 |
+
pjit,
|
119 |
+
in_shardings=(model_ps, PS(), PS(), PS()),
|
120 |
+
out_shardings=(PS(), PS())
|
121 |
+
)
|
122 |
+
def forward_generate(params, rng, batch, temperature):
|
123 |
+
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
|
124 |
+
rng_generator = JaxRNG(rng)
|
125 |
+
output = hf_model.generate(
|
126 |
+
batch['input_tokens'],
|
127 |
+
attention_mask=batch['attention_mask'],
|
128 |
+
params=params['params'],
|
129 |
+
prng_key=rng_generator(),
|
130 |
+
logits_processor=FlaxLogitsProcessorList(
|
131 |
+
[FlaxTemperatureLogitsWarper(temperature)]
|
132 |
+
),
|
133 |
+
generation_config=GenerationConfig(
|
134 |
+
max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
|
135 |
+
pad_token_id=tokenizer.eos_token_id,
|
136 |
+
bos_token_id=tokenizer.bos_token_id,
|
137 |
+
eos_token_id=tokenizer.eos_token_id,
|
138 |
+
do_sample=FLAGS.do_sample,
|
139 |
+
num_beams=FLAGS.num_beams,
|
140 |
+
top_k=FLAGS.top_k,
|
141 |
+
top_p=FLAGS.top_p,
|
142 |
+
)
|
143 |
+
).sequences[:, batch['input_tokens'].shape[1]:]
|
144 |
+
return output, rng_generator()
|
145 |
+
|
146 |
+
@partial(
|
147 |
+
pjit,
|
148 |
+
in_shardings=(model_ps, PS(), PS()),
|
149 |
+
out_shardings=(PS(), PS())
|
150 |
+
)
|
151 |
+
def forward_greedy_generate(params, rng, batch):
|
152 |
+
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
|
153 |
+
rng_generator = JaxRNG(rng)
|
154 |
+
output = hf_model.generate(
|
155 |
+
batch['input_tokens'],
|
156 |
+
attention_mask=batch['attention_mask'],
|
157 |
+
params=params['params'],
|
158 |
+
prng_key=rng_generator(),
|
159 |
+
generation_config=GenerationConfig(
|
160 |
+
max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
|
161 |
+
pad_token_id=tokenizer.eos_token_id,
|
162 |
+
bos_token_id=tokenizer.bos_token_id,
|
163 |
+
eos_token_id=tokenizer.eos_token_id,
|
164 |
+
do_sample=False,
|
165 |
+
num_beams=1,
|
166 |
+
)
|
167 |
+
).sequences[:, batch['input_tokens'].shape[1]:]
|
168 |
+
return output, rng_generator()
|
169 |
+
|
170 |
+
mesh = GPTJConfig.get_jax_mesh(FLAGS.mesh_dim)
|
171 |
+
with mesh:
|
172 |
+
params = tree_apply(shard_fns, params)
|
173 |
+
sharded_rng = next_rng()
|
174 |
+
|
175 |
+
class ModelServer(LMServer):
|
176 |
+
|
177 |
+
@staticmethod
|
178 |
+
def loglikelihood(prefix_text, text):
|
179 |
+
nonlocal sharded_rng
|
180 |
+
prefix = prefix_tokenizer(
|
181 |
+
prefix_text,
|
182 |
+
padding='max_length',
|
183 |
+
truncation=True,
|
184 |
+
max_length=FLAGS.input_length,
|
185 |
+
return_tensors='np',
|
186 |
+
)
|
187 |
+
inputs = tokenizer(
|
188 |
+
text,
|
189 |
+
padding='max_length',
|
190 |
+
truncation=True,
|
191 |
+
max_length=FLAGS.seq_length - FLAGS.input_length,
|
192 |
+
return_tensors='np',
|
193 |
+
)
|
194 |
+
output_tokens = np.concatenate([prefix.input_ids, inputs.input_ids], axis=1)
|
195 |
+
bos_tokens = np.full(
|
196 |
+
(output_tokens.shape[0], 1), tokenizer.bos_token_id, dtype=np.int32
|
197 |
+
)
|
198 |
+
input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1)
|
199 |
+
input_mask = np.concatenate(
|
200 |
+
[prefix.attention_mask, inputs.attention_mask], axis=1
|
201 |
+
)
|
202 |
+
if FLAGS.add_bos_token:
|
203 |
+
bos_mask = np.ones_like(input_mask[:, :1])
|
204 |
+
else:
|
205 |
+
bos_mask = np.zeros_like(input_mask[:, :1])
|
206 |
+
|
207 |
+
input_mask = np.concatenate([bos_mask, input_mask[:, :-1]], axis=1)
|
208 |
+
output_mask = np.concatenate(
|
209 |
+
[np.zeros_like(prefix.attention_mask), inputs.attention_mask], axis=1
|
210 |
+
)
|
211 |
+
batch = dict(
|
212 |
+
input_tokens=input_tokens,
|
213 |
+
output_tokens=output_tokens,
|
214 |
+
input_mask=input_mask,
|
215 |
+
output_mask=output_mask,
|
216 |
+
)
|
217 |
+
with mesh:
|
218 |
+
loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
|
219 |
+
params, sharded_rng, batch
|
220 |
+
)
|
221 |
+
loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy))
|
222 |
+
return loglikelihood, is_greedy
|
223 |
+
|
224 |
+
@staticmethod
|
225 |
+
def loglikelihood_rolling(text):
|
226 |
+
nonlocal sharded_rng
|
227 |
+
inputs = tokenizer(
|
228 |
+
text,
|
229 |
+
padding='longest',
|
230 |
+
truncation=False,
|
231 |
+
max_length=np.iinfo(np.int32).max,
|
232 |
+
return_tensors='np',
|
233 |
+
)
|
234 |
+
batch_size = inputs.input_ids.shape[0]
|
235 |
+
output_tokens = inputs.input_ids
|
236 |
+
attention_mask = inputs.attention_mask
|
237 |
+
|
238 |
+
if output_tokens.shape[1] < FLAGS.seq_length:
|
239 |
+
padding_length = FLAGS.seq_length - output_tokens.shape[1]
|
240 |
+
pad_tokens = np.full(
|
241 |
+
(batch_size, padding_length), tokenizer.pad_token_id, dtype=np.int32
|
242 |
+
)
|
243 |
+
output_tokens = np.concatenate([output_tokens, pad_tokens], axis=-1)
|
244 |
+
pad_mask = np.zeros(
|
245 |
+
(batch_size, padding_length), dtype=inputs.attention_mask.dtype
|
246 |
+
)
|
247 |
+
attention_mask = np.concatenate([attention_mask, pad_mask], axis=-1)
|
248 |
+
|
249 |
+
bos_tokens = np.full(
|
250 |
+
(batch_size, 1), tokenizer.bos_token_id, dtype=np.int32
|
251 |
+
)
|
252 |
+
input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1)
|
253 |
+
bos_mask = np.ones((batch_size, 1), dtype=inputs.attention_mask.dtype)
|
254 |
+
total_seq_length = output_tokens.shape[1]
|
255 |
+
|
256 |
+
total_loglikelihood = 0.0
|
257 |
+
total_is_greedy = True
|
258 |
+
# Sliding window
|
259 |
+
for i in range(0, total_seq_length, FLAGS.seq_length):
|
260 |
+
# Last window
|
261 |
+
if i + FLAGS.seq_length > total_seq_length:
|
262 |
+
last_output_mask = np.copy(attention_mask[:, -FLAGS.seq_length:])
|
263 |
+
last_output_mask[:, :i - total_seq_length] = 0.0
|
264 |
+
|
265 |
+
batch = dict(
|
266 |
+
input_tokens=input_tokens[:, -FLAGS.seq_length:],
|
267 |
+
output_tokens=output_tokens[:, -FLAGS.seq_length:],
|
268 |
+
input_mask=attention_mask[:, -FLAGS.seq_length:],
|
269 |
+
output_mask=last_output_mask,
|
270 |
+
)
|
271 |
+
|
272 |
+
# Normal window
|
273 |
+
else:
|
274 |
+
batch = dict(
|
275 |
+
input_tokens=input_tokens[:, i:i + FLAGS.seq_length],
|
276 |
+
output_tokens=output_tokens[:, i:i + FLAGS.seq_length],
|
277 |
+
input_mask=attention_mask[:, i:i + FLAGS.seq_length],
|
278 |
+
output_mask=attention_mask[:, i:i + FLAGS.seq_length],
|
279 |
+
)
|
280 |
+
|
281 |
+
with mesh:
|
282 |
+
loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
|
283 |
+
params, sharded_rng, batch
|
284 |
+
)
|
285 |
+
loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy))
|
286 |
+
|
287 |
+
total_loglikelihood += loglikelihood
|
288 |
+
total_is_greedy = np.logical_and(is_greedy, total_is_greedy)
|
289 |
+
|
290 |
+
return total_loglikelihood, total_is_greedy
|
291 |
+
|
292 |
+
@staticmethod
|
293 |
+
def generate(text, temperature):
|
294 |
+
nonlocal sharded_rng
|
295 |
+
inputs = prefix_tokenizer(
|
296 |
+
text,
|
297 |
+
padding='max_length',
|
298 |
+
truncation=True,
|
299 |
+
max_length=FLAGS.input_length,
|
300 |
+
return_tensors='np',
|
301 |
+
)
|
302 |
+
input_tokens = inputs.input_ids
|
303 |
+
input_mask = inputs.attention_mask
|
304 |
+
if FLAGS.add_bos_token:
|
305 |
+
input_tokens[:, 0] = tokenizer.bos_token_id
|
306 |
+
input_mask[:, 0] = 1
|
307 |
+
batch = dict(
|
308 |
+
input_tokens=input_tokens,
|
309 |
+
attention_mask=input_mask,
|
310 |
+
)
|
311 |
+
with mesh:
|
312 |
+
output, sharded_rng = forward_generate(
|
313 |
+
params, sharded_rng, batch, temperature
|
314 |
+
)
|
315 |
+
output = jax.device_get(output)
|
316 |
+
output_text = []
|
317 |
+
for text in list(tokenizer.batch_decode(output)):
|
318 |
+
if tokenizer.eos_token in text:
|
319 |
+
text = text.split(tokenizer.eos_token, maxsplit=1)[0]
|
320 |
+
output_text.append(text)
|
321 |
+
|
322 |
+
return output_text
|
323 |
+
|
324 |
+
@staticmethod
|
325 |
+
def greedy_until(prefix_text, until, max_length):
|
326 |
+
nonlocal sharded_rng
|
327 |
+
all_outputs = []
|
328 |
+
for pf, ut in zip(prefix_text, until):
|
329 |
+
if isinstance(ut, str):
|
330 |
+
ut = [ut]
|
331 |
+
total_length = 0
|
332 |
+
total_generated = ''
|
333 |
+
|
334 |
+
while total_length < max_length:
|
335 |
+
pf_tokens = tokenizer(
|
336 |
+
pf,
|
337 |
+
padding=False,
|
338 |
+
truncation=False,
|
339 |
+
max_length=np.iinfo(np.int32).max,
|
340 |
+
return_tensors='np',
|
341 |
+
)
|
342 |
+
input_tokens = pf_tokens.input_ids
|
343 |
+
attention_mask = pf_tokens.attention_mask
|
344 |
+
|
345 |
+
if input_tokens.shape[1] < FLAGS.input_length:
|
346 |
+
extra = FLAGS.input_length - input_tokens.shape[1]
|
347 |
+
pad_tokens = np.full(
|
348 |
+
(1, extra), tokenizer.pad_token_id, dtype=np.int32
|
349 |
+
)
|
350 |
+
input_tokens = np.concatenate(
|
351 |
+
[pad_tokens, input_tokens], axis=1
|
352 |
+
)
|
353 |
+
pad_attention = np.zeros((1, extra), dtype=attention_mask.dtype)
|
354 |
+
attention_mask = np.concatenate(
|
355 |
+
[pad_attention, attention_mask], axis=1
|
356 |
+
)
|
357 |
+
elif input_tokens.shape[1] > FLAGS.input_length:
|
358 |
+
input_tokens = input_tokens[:, -FLAGS.input_length:]
|
359 |
+
attention_mask = attention_mask[:, -FLAGS.input_length:]
|
360 |
+
|
361 |
+
if FLAGS.add_bos_token:
|
362 |
+
input_tokens[:, 0] = tokenizer.bos_token_id
|
363 |
+
attention_mask[:, 0] = 1
|
364 |
+
|
365 |
+
batch = dict(input_tokens=input_tokens, attention_mask=attention_mask)
|
366 |
+
|
367 |
+
with mesh:
|
368 |
+
output, sharded_rng = forward_greedy_generate(
|
369 |
+
params, sharded_rng, batch
|
370 |
+
)
|
371 |
+
output = jax.device_get(output)
|
372 |
+
|
373 |
+
total_length += output.shape[1]
|
374 |
+
output_text = tokenizer.batch_decode(output)[0]
|
375 |
+
total_generated = total_generated + output_text
|
376 |
+
pf = pf + output_text
|
377 |
+
|
378 |
+
done = False
|
379 |
+
for s in ut:
|
380 |
+
if s in total_generated:
|
381 |
+
total_generated = total_generated.split(s, maxsplit=1)[0]
|
382 |
+
done = True
|
383 |
+
if done:
|
384 |
+
break
|
385 |
+
|
386 |
+
all_outputs.append(total_generated)
|
387 |
+
|
388 |
+
return all_outputs
|
389 |
+
|
390 |
+
|
391 |
+
server = ModelServer(FLAGS.lm_server)
|
392 |
+
server.run()
|
393 |
+
|
394 |
+
|
395 |
+
if __name__ == "__main__":
|
396 |
+
mlxu.run(main)
|
EasyLM/models/gptj/gptj_train.py
ADDED
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pprint
|
2 |
+
from functools import partial
|
3 |
+
|
4 |
+
from tqdm import tqdm, trange
|
5 |
+
import numpy as np
|
6 |
+
import mlxu
|
7 |
+
|
8 |
+
import jax
|
9 |
+
import jax.numpy as jnp
|
10 |
+
from jax.experimental.pjit import pjit, with_sharding_constraint
|
11 |
+
from jax.sharding import PartitionSpec as PS
|
12 |
+
from flax.training.train_state import TrainState
|
13 |
+
|
14 |
+
from EasyLM.data import DatasetFactory
|
15 |
+
from EasyLM.checkpoint import StreamingCheckpointer
|
16 |
+
from EasyLM.optimizers import OptimizerFactory
|
17 |
+
from EasyLM.jax_utils import (
|
18 |
+
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules,
|
19 |
+
cross_entropy_loss_and_accuracy, global_norm, get_float_dtype_by_name,
|
20 |
+
set_random_seed, average_metrics, get_weight_decay_mask,
|
21 |
+
make_shard_and_gather_fns, tree_apply
|
22 |
+
)
|
23 |
+
from EasyLM.models.gptj.gptj_model import GPTJConfig, FlaxGPTJForCausalLMModule
|
24 |
+
|
25 |
+
|
26 |
+
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
27 |
+
seed=42,
|
28 |
+
mesh_dim='1,-1,1',
|
29 |
+
dtype='fp32',
|
30 |
+
total_steps=10000,
|
31 |
+
load_gptj_config='',
|
32 |
+
update_gptj_config='',
|
33 |
+
load_checkpoint='',
|
34 |
+
load_dataset_state='',
|
35 |
+
log_freq=50,
|
36 |
+
save_model_freq=0,
|
37 |
+
save_milestone_freq=0,
|
38 |
+
eval_steps=0,
|
39 |
+
tokenizer=GPTJConfig.get_tokenizer_config(),
|
40 |
+
train_dataset=DatasetFactory.get_default_config(),
|
41 |
+
eval_dataset=DatasetFactory.get_default_config(),
|
42 |
+
optimizer=OptimizerFactory.get_default_config(),
|
43 |
+
checkpointer=StreamingCheckpointer.get_default_config(),
|
44 |
+
gptj=GPTJConfig.get_default_config(),
|
45 |
+
logger=mlxu.WandBLogger.get_default_config(),
|
46 |
+
log_all_worker=False,
|
47 |
+
jax_distributed=JaxDistributedConfig.get_default_config(),
|
48 |
+
)
|
49 |
+
|
50 |
+
|
51 |
+
def main(argv):
|
52 |
+
JaxDistributedConfig.initialize(FLAGS.jax_distributed)
|
53 |
+
variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
|
54 |
+
flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
|
55 |
+
logger = mlxu.WandBLogger(
|
56 |
+
config=FLAGS.logger,
|
57 |
+
variant=variant,
|
58 |
+
enable=FLAGS.log_all_worker or (jax.process_index() == 0),
|
59 |
+
)
|
60 |
+
set_random_seed(FLAGS.seed)
|
61 |
+
|
62 |
+
tokenizer = GPTJConfig.get_tokenizer(FLAGS.tokenizer)
|
63 |
+
dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer)
|
64 |
+
if FLAGS.load_dataset_state != '':
|
65 |
+
dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state))
|
66 |
+
|
67 |
+
if FLAGS.eval_steps > 0:
|
68 |
+
eval_dataset = DatasetFactory.load_dataset(
|
69 |
+
FLAGS.eval_dataset, dataset.tokenizer
|
70 |
+
)
|
71 |
+
eval_iterator = iter(eval_dataset)
|
72 |
+
|
73 |
+
seq_length = dataset.seq_length
|
74 |
+
|
75 |
+
if FLAGS.load_gptj_config != '':
|
76 |
+
gptj_config = GPTJConfig.load_config(FLAGS.load_gptj_config)
|
77 |
+
else:
|
78 |
+
gptj_config = GPTJConfig(**FLAGS.gptj)
|
79 |
+
|
80 |
+
if FLAGS.update_gptj_config != '':
|
81 |
+
gptj_config.update(dict(eval(FLAGS.update_gptj_config)))
|
82 |
+
|
83 |
+
gptj_config.update(dict(
|
84 |
+
bos_token_id=dataset.tokenizer.bos_token_id,
|
85 |
+
eos_token_id=dataset.tokenizer.eos_token_id,
|
86 |
+
))
|
87 |
+
if gptj_config.vocab_size < dataset.vocab_size:
|
88 |
+
gptj_config.update(dict(vocab_size=dataset.vocab_size))
|
89 |
+
|
90 |
+
model = FlaxGPTJForCausalLMModule(
|
91 |
+
gptj_config, dtype=get_float_dtype_by_name(FLAGS.dtype)
|
92 |
+
)
|
93 |
+
|
94 |
+
optimizer, optimizer_info = OptimizerFactory.get_optimizer(
|
95 |
+
FLAGS.optimizer,
|
96 |
+
get_weight_decay_mask(GPTJConfig.get_weight_decay_exclusions()),
|
97 |
+
)
|
98 |
+
|
99 |
+
def create_trainstate_from_params(params):
|
100 |
+
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
|
101 |
+
|
102 |
+
def init_fn(rng):
|
103 |
+
rng_generator = JaxRNG(rng)
|
104 |
+
params = model.init(
|
105 |
+
input_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
|
106 |
+
position_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
|
107 |
+
attention_mask=jnp.ones((4, seq_length), dtype=jnp.int32),
|
108 |
+
rngs=rng_generator(gptj_config.rng_keys()),
|
109 |
+
)
|
110 |
+
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
|
111 |
+
|
112 |
+
def train_step(train_state, rng, batch):
|
113 |
+
rng_generator = JaxRNG(rng)
|
114 |
+
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
|
115 |
+
def loss_and_accuracy(params):
|
116 |
+
logits = model.apply(
|
117 |
+
params, batch['input_tokens'], deterministic=False,
|
118 |
+
rngs=rng_generator(gptj_config.rng_keys()),
|
119 |
+
).logits
|
120 |
+
return cross_entropy_loss_and_accuracy(
|
121 |
+
logits, batch['target_tokens'], batch['loss_masks']
|
122 |
+
)
|
123 |
+
grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
|
124 |
+
(loss, accuracy), grads = grad_fn(train_state.params)
|
125 |
+
train_state = train_state.apply_gradients(grads=grads)
|
126 |
+
metrics = dict(
|
127 |
+
loss=loss,
|
128 |
+
accuracy=accuracy,
|
129 |
+
learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
|
130 |
+
gradient_norm=global_norm(grads),
|
131 |
+
param_norm=global_norm(train_state.params),
|
132 |
+
)
|
133 |
+
return train_state, rng_generator(), metrics
|
134 |
+
|
135 |
+
def eval_step(train_state, rng, batch):
|
136 |
+
rng_generator = JaxRNG(rng)
|
137 |
+
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
|
138 |
+
logits = model.apply(
|
139 |
+
train_state.params, batch['input_tokens'], deterministic=True,
|
140 |
+
rngs=rng_generator(gptj_config.rng_keys()),
|
141 |
+
).logits
|
142 |
+
loss, accuracy = cross_entropy_loss_and_accuracy(
|
143 |
+
logits, batch['target_tokens'], batch['loss_masks']
|
144 |
+
)
|
145 |
+
metrics = dict(
|
146 |
+
eval_loss=loss,
|
147 |
+
eval_accuracy=accuracy,
|
148 |
+
)
|
149 |
+
return rng_generator(), metrics
|
150 |
+
|
151 |
+
train_state_shapes = jax.eval_shape(init_fn, next_rng())
|
152 |
+
train_state_partition = match_partition_rules(
|
153 |
+
GPTJConfig.get_partition_rules(), train_state_shapes
|
154 |
+
)
|
155 |
+
|
156 |
+
shard_fns, gather_fns = make_shard_and_gather_fns(
|
157 |
+
train_state_partition, train_state_shapes
|
158 |
+
)
|
159 |
+
checkpointer = StreamingCheckpointer(
|
160 |
+
FLAGS.checkpointer, logger.output_dir,
|
161 |
+
enable=jax.process_index() == 0,
|
162 |
+
)
|
163 |
+
|
164 |
+
sharded_init_fn = pjit(
|
165 |
+
init_fn,
|
166 |
+
in_shardings=PS(),
|
167 |
+
out_shardings=train_state_partition
|
168 |
+
)
|
169 |
+
|
170 |
+
sharded_create_trainstate_from_params = pjit(
|
171 |
+
create_trainstate_from_params,
|
172 |
+
in_shardings=(train_state_partition.params, ),
|
173 |
+
out_shardings=train_state_partition,
|
174 |
+
donate_argnums=(0, ),
|
175 |
+
)
|
176 |
+
|
177 |
+
sharded_train_step = pjit(
|
178 |
+
train_step,
|
179 |
+
in_shardings=(train_state_partition, PS(), PS()),
|
180 |
+
out_shardings=(train_state_partition, PS(), PS()),
|
181 |
+
donate_argnums=(0, 1),
|
182 |
+
)
|
183 |
+
|
184 |
+
sharded_eval_step = pjit(
|
185 |
+
eval_step,
|
186 |
+
in_shardings=(train_state_partition, PS(), PS()),
|
187 |
+
out_shardings=(PS(), PS()),
|
188 |
+
donate_argnums=(1,),
|
189 |
+
)
|
190 |
+
|
191 |
+
def save_checkpoint(train_state, milestone=False):
|
192 |
+
step = int(jax.device_get(train_state.step))
|
193 |
+
metadata = dict(
|
194 |
+
step=step,
|
195 |
+
variant=variant,
|
196 |
+
flags=flags_config_dict,
|
197 |
+
gptj_config=gptj_config.to_dict(),
|
198 |
+
)
|
199 |
+
checkpointer.save_all(
|
200 |
+
train_state=train_state,
|
201 |
+
gather_fns=gather_fns,
|
202 |
+
metadata=metadata,
|
203 |
+
dataset=dataset.get_state_dict(),
|
204 |
+
milestone=milestone,
|
205 |
+
)
|
206 |
+
|
207 |
+
mesh = GPTJConfig.get_jax_mesh(FLAGS.mesh_dim)
|
208 |
+
with mesh:
|
209 |
+
train_state, restored_params = None, None
|
210 |
+
if FLAGS.load_checkpoint != '':
|
211 |
+
load_type, load_path = FLAGS.load_checkpoint.split('::', 1)
|
212 |
+
if load_type == 'huggingface':
|
213 |
+
restored_params = tree_apply(
|
214 |
+
shard_fns.params, gptj_config.load_pretrained(load_path)
|
215 |
+
)
|
216 |
+
train_state = None
|
217 |
+
else:
|
218 |
+
train_state, restored_params = checkpointer.load_trainstate_checkpoint(
|
219 |
+
FLAGS.load_checkpoint, train_state_shapes, shard_fns
|
220 |
+
)
|
221 |
+
|
222 |
+
if train_state is None and restored_params is None:
|
223 |
+
# Initialize from scratch
|
224 |
+
train_state = sharded_init_fn(next_rng())
|
225 |
+
elif train_state is None and restored_params is not None:
|
226 |
+
# Restore from params but initialize train_state
|
227 |
+
train_state = sharded_create_trainstate_from_params(restored_params)
|
228 |
+
del restored_params
|
229 |
+
|
230 |
+
start_step = int(jax.device_get(train_state.step))
|
231 |
+
|
232 |
+
if FLAGS.save_model_freq > 0:
|
233 |
+
save_checkpoint(train_state)
|
234 |
+
|
235 |
+
sharded_rng = next_rng()
|
236 |
+
|
237 |
+
step_counter = trange(start_step, FLAGS.total_steps, ncols=0)
|
238 |
+
|
239 |
+
for step, (batch, dataset_metrics) in zip(step_counter, dataset):
|
240 |
+
train_state, sharded_rng, metrics = sharded_train_step(
|
241 |
+
train_state, sharded_rng, batch
|
242 |
+
)
|
243 |
+
|
244 |
+
if step % FLAGS.log_freq == 0:
|
245 |
+
if FLAGS.eval_steps > 0:
|
246 |
+
eval_metric_list = []
|
247 |
+
for _ in range(FLAGS.eval_steps):
|
248 |
+
eval_batch, _ = next(eval_iterator)
|
249 |
+
sharded_rng, eval_metrics = sharded_eval_step(
|
250 |
+
train_state, sharded_rng, eval_batch
|
251 |
+
)
|
252 |
+
eval_metric_list.append(eval_metrics)
|
253 |
+
metrics.update(average_metrics(eval_metric_list))
|
254 |
+
|
255 |
+
log_metrics = {"step": step}
|
256 |
+
log_metrics.update(metrics)
|
257 |
+
log_metrics.update(dataset_metrics)
|
258 |
+
log_metrics = jax.device_get(log_metrics)
|
259 |
+
logger.log(log_metrics)
|
260 |
+
tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")
|
261 |
+
|
262 |
+
if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0:
|
263 |
+
save_checkpoint(train_state, milestone=True)
|
264 |
+
elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0:
|
265 |
+
save_checkpoint(train_state)
|
266 |
+
|
267 |
+
if FLAGS.save_model_freq > 0:
|
268 |
+
save_checkpoint(train_state)
|
269 |
+
|
270 |
+
|
271 |
+
if __name__ == "__main__":
|
272 |
+
mlxu.run(main)
|
EasyLM/models/llama/convert_easylm_to_hf.py
ADDED
@@ -0,0 +1,338 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
|
2 |
+
# Copyright 2023 Xinyang Geng
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
# This script converts LLaMA model checkpoint trained by EsayLM to the
|
17 |
+
# HuggingFace transformers LLaMA PyTorch format, which can then be loaded
|
18 |
+
# by HuggingFace transformers.
|
19 |
+
|
20 |
+
import gc
|
21 |
+
import json
|
22 |
+
import math
|
23 |
+
import os
|
24 |
+
import shutil
|
25 |
+
|
26 |
+
import numpy as np
|
27 |
+
import mlxu
|
28 |
+
import jax
|
29 |
+
import jax.numpy as jnp
|
30 |
+
import flax
|
31 |
+
from flax.traverse_util import flatten_dict
|
32 |
+
import torch
|
33 |
+
from transformers import LlamaConfig, LlamaForCausalLM
|
34 |
+
|
35 |
+
from EasyLM.checkpoint import StreamingCheckpointer
|
36 |
+
from EasyLM.jax_utils import float_tensor_to_dtype
|
37 |
+
|
38 |
+
|
39 |
+
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
40 |
+
load_checkpoint='',
|
41 |
+
tokenizer_path='',
|
42 |
+
model_size='13b',
|
43 |
+
output_dir='',
|
44 |
+
)
|
45 |
+
|
46 |
+
|
47 |
+
LLAMA_STANDARD_CONFIGS = {
|
48 |
+
'small': {
|
49 |
+
'vocab_size': 64256,
|
50 |
+
'dim': 768,
|
51 |
+
'intermediate_size': 3072,
|
52 |
+
'n_layers': 12,
|
53 |
+
'n_heads': 12,
|
54 |
+
'norm_eps': 1e-6,
|
55 |
+
},
|
56 |
+
'medium': {
|
57 |
+
'vocab_size': 64256,
|
58 |
+
'dim': 1024,
|
59 |
+
'intermediate_size': 4096,
|
60 |
+
'n_layers': 24,
|
61 |
+
'n_heads': 16,
|
62 |
+
'norm_eps': 1e-6,
|
63 |
+
},
|
64 |
+
'large': {
|
65 |
+
'vocab_size': 64256,
|
66 |
+
'dim': 1536,
|
67 |
+
'intermediate_size': 6144,
|
68 |
+
'n_layers': 24,
|
69 |
+
'n_heads': 16,
|
70 |
+
'norm_eps': 1e-6,
|
71 |
+
},
|
72 |
+
'xlarge': {
|
73 |
+
'vocab_size': 64256,
|
74 |
+
'dim': 2048,
|
75 |
+
'intermediate_size': 8192,
|
76 |
+
'n_layers': 24,
|
77 |
+
'n_heads': 32,
|
78 |
+
'norm_eps': 1e-6,
|
79 |
+
},
|
80 |
+
'1b': {
|
81 |
+
'vocab_size': 64256,
|
82 |
+
'dim': 2048,
|
83 |
+
'intermediate_size': 5504,
|
84 |
+
'n_layers': 22,
|
85 |
+
'n_heads': 16,
|
86 |
+
'norm_eps': 1e-6,
|
87 |
+
},
|
88 |
+
'3b': {
|
89 |
+
'vocab_size': 64256,
|
90 |
+
'dim': 3200,
|
91 |
+
'intermediate_size': 8640,
|
92 |
+
'n_layers': 26,
|
93 |
+
'n_heads': 32,
|
94 |
+
'norm_eps': 1e-6,
|
95 |
+
},
|
96 |
+
'7b': {
|
97 |
+
'vocab_size': 64256,
|
98 |
+
'dim': 4096,
|
99 |
+
'intermediate_size': 11008,
|
100 |
+
'n_layers': 32,
|
101 |
+
'n_heads': 32,
|
102 |
+
'norm_eps': 1e-6,
|
103 |
+
},
|
104 |
+
'13b': {
|
105 |
+
'vocab_size': 64256,
|
106 |
+
'dim': 5120,
|
107 |
+
'intermediate_size': 13824,
|
108 |
+
'n_layers': 40,
|
109 |
+
'n_heads': 40,
|
110 |
+
'norm_eps': 1e-6,
|
111 |
+
},
|
112 |
+
'30b': {
|
113 |
+
'vocab_size': 64256,
|
114 |
+
'dim': 6656,
|
115 |
+
'intermediate_size': 17920,
|
116 |
+
'n_layers': 60,
|
117 |
+
'n_heads': 52,
|
118 |
+
'norm_eps': 1e-6,
|
119 |
+
},
|
120 |
+
'65b': {
|
121 |
+
'vocab_size': 64256,
|
122 |
+
'dim': 8192,
|
123 |
+
'intermediate_size': 22016,
|
124 |
+
'n_layers': 80,
|
125 |
+
'n_heads': 64,
|
126 |
+
'norm_eps': 1e-5,
|
127 |
+
},
|
128 |
+
}
|
129 |
+
|
130 |
+
|
131 |
+
def match_keywords(string, positives, negatives):
|
132 |
+
for positive in positives:
|
133 |
+
if positive not in string:
|
134 |
+
return False
|
135 |
+
for negative in negatives:
|
136 |
+
if negative in string:
|
137 |
+
return False
|
138 |
+
return True
|
139 |
+
|
140 |
+
|
141 |
+
def load_and_convert_checkpoint(path):
|
142 |
+
_, flax_params = StreamingCheckpointer.load_trainstate_checkpoint(path)
|
143 |
+
flax_params = flatten_dict(flax_params['params'], sep='.')
|
144 |
+
torch_params = {}
|
145 |
+
for key, tensor in flax_params.items():
|
146 |
+
if match_keywords(key, ["kernel"], ["norm", 'ln_f']):
|
147 |
+
tensor = tensor.T
|
148 |
+
torch_params[key] = torch.tensor(
|
149 |
+
float_tensor_to_dtype(tensor, 'fp32'), dtype=torch.float16
|
150 |
+
)
|
151 |
+
return torch_params
|
152 |
+
|
153 |
+
|
154 |
+
def read_json(path):
|
155 |
+
with open(path, "r") as f:
|
156 |
+
return json.load(f)
|
157 |
+
|
158 |
+
|
159 |
+
def write_json(text, path):
|
160 |
+
with open(path, "w") as f:
|
161 |
+
json.dump(text, f)
|
162 |
+
|
163 |
+
|
164 |
+
def write_model(loaded, model_path, model_size):
|
165 |
+
os.makedirs(model_path, exist_ok=True)
|
166 |
+
tmp_model_path = os.path.join(model_path, "tmp")
|
167 |
+
os.makedirs(tmp_model_path, exist_ok=True)
|
168 |
+
|
169 |
+
params = LLAMA_STANDARD_CONFIGS[model_size]
|
170 |
+
|
171 |
+
n_layers = params["n_layers"]
|
172 |
+
n_heads = params["n_heads"]
|
173 |
+
dim = params["dim"]
|
174 |
+
dims_per_head = dim // n_heads
|
175 |
+
base = 10000.0
|
176 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
|
177 |
+
|
178 |
+
# permute for sliced rotary
|
179 |
+
def permute(w):
|
180 |
+
return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)
|
181 |
+
|
182 |
+
|
183 |
+
param_count = 0
|
184 |
+
index_dict = {"weight_map": {}}
|
185 |
+
for layer_i in range(n_layers):
|
186 |
+
filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
|
187 |
+
state_dict = {
|
188 |
+
f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
|
189 |
+
loaded[f"transformer.h.{layer_i}.attention.wq.kernel"]
|
190 |
+
),
|
191 |
+
f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
|
192 |
+
loaded[f"transformer.h.{layer_i}.attention.wk.kernel"]
|
193 |
+
),
|
194 |
+
f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"transformer.h.{layer_i}.attention.wv.kernel"],
|
195 |
+
f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"transformer.h.{layer_i}.attention.wo.kernel"],
|
196 |
+
|
197 |
+
f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"transformer.h.{layer_i}.feed_forward.w1.kernel"],
|
198 |
+
f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"transformer.h.{layer_i}.feed_forward.w2.kernel"],
|
199 |
+
f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"transformer.h.{layer_i}.feed_forward.w3.kernel"],
|
200 |
+
|
201 |
+
f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"transformer.h.{layer_i}.attention_norm.kernel"],
|
202 |
+
f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"transformer.h.{layer_i}.ffn_norm.kernel"],
|
203 |
+
|
204 |
+
}
|
205 |
+
|
206 |
+
state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
|
207 |
+
for k, v in state_dict.items():
|
208 |
+
index_dict["weight_map"][k] = filename
|
209 |
+
param_count += v.numel()
|
210 |
+
torch.save(state_dict, os.path.join(tmp_model_path, filename))
|
211 |
+
|
212 |
+
filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
|
213 |
+
# Unsharded
|
214 |
+
state_dict = {
|
215 |
+
"model.embed_tokens.weight": loaded["transformer.wte.embedding"],
|
216 |
+
"model.norm.weight": loaded["transformer.ln_f.kernel"],
|
217 |
+
"lm_head.weight": loaded["lm_head.kernel"],
|
218 |
+
}
|
219 |
+
|
220 |
+
for k, v in state_dict.items():
|
221 |
+
index_dict["weight_map"][k] = filename
|
222 |
+
param_count += v.numel()
|
223 |
+
torch.save(state_dict, os.path.join(tmp_model_path, filename))
|
224 |
+
|
225 |
+
# Write configs
|
226 |
+
index_dict["metadata"] = {"total_size": param_count * 2}
|
227 |
+
write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
|
228 |
+
|
229 |
+
config = LlamaConfig(
|
230 |
+
vocab_size=params["vocab_size"],
|
231 |
+
hidden_size=dim,
|
232 |
+
intermediate_size=params["intermediate_size"],
|
233 |
+
num_attention_heads=params["n_heads"],
|
234 |
+
num_hidden_layers=params["n_layers"],
|
235 |
+
rms_norm_eps=params["norm_eps"],
|
236 |
+
)
|
237 |
+
config.save_pretrained(tmp_model_path)
|
238 |
+
|
239 |
+
# Make space so we can load the model properly now.
|
240 |
+
del state_dict
|
241 |
+
del loaded
|
242 |
+
gc.collect()
|
243 |
+
|
244 |
+
print("Loading the checkpoint in a Llama model.")
|
245 |
+
model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float16)
|
246 |
+
# Avoid saving this as part of the config.
|
247 |
+
print("Model parameter count", model.num_parameters())
|
248 |
+
del model.config._name_or_path
|
249 |
+
|
250 |
+
print("Saving in the Transformers format.")
|
251 |
+
model.save_pretrained(model_path, safe_serialization=True)
|
252 |
+
shutil.rmtree(tmp_model_path)
|
253 |
+
|
254 |
+
|
255 |
+
def write_tokenizer(tokenizer_path, input_tokenizer_path):
|
256 |
+
print(f"Fetching the tokenizer from {input_tokenizer_path}.")
|
257 |
+
os.makedirs(tokenizer_path, exist_ok=True)
|
258 |
+
write_json(
|
259 |
+
{
|
260 |
+
"bos_token": {
|
261 |
+
"content": "<s>",
|
262 |
+
"lstrip": False,
|
263 |
+
"normalized": True,
|
264 |
+
"rstrip": False,
|
265 |
+
"single_word": False
|
266 |
+
},
|
267 |
+
"eos_token": {
|
268 |
+
"content": "</s>",
|
269 |
+
"lstrip": False,
|
270 |
+
"normalized": True,
|
271 |
+
"rstrip": False,
|
272 |
+
"single_word": False
|
273 |
+
},
|
274 |
+
"unk_token": {
|
275 |
+
"content": "<unk>",
|
276 |
+
"lstrip": False,
|
277 |
+
"normalized": True,
|
278 |
+
"rstrip": False,
|
279 |
+
"single_word": False
|
280 |
+
},
|
281 |
+
},
|
282 |
+
os.path.join(tokenizer_path, "special_tokens_map.json")
|
283 |
+
)
|
284 |
+
write_json(
|
285 |
+
{
|
286 |
+
"add_bos_token": True,
|
287 |
+
"add_eos_token": False,
|
288 |
+
"model_max_length": 2048,
|
289 |
+
"pad_token": None,
|
290 |
+
"sp_model_kwargs": {},
|
291 |
+
"tokenizer_class": "LlamaTokenizer",
|
292 |
+
"clean_up_tokenization_spaces": False,
|
293 |
+
"bos_token": {
|
294 |
+
"__type": "AddedToken",
|
295 |
+
"content": "<s>",
|
296 |
+
"lstrip": False,
|
297 |
+
"normalized": True,
|
298 |
+
"rstrip": False,
|
299 |
+
"single_word": False
|
300 |
+
},
|
301 |
+
"eos_token": {
|
302 |
+
"__type": "AddedToken",
|
303 |
+
"content": "</s>",
|
304 |
+
"lstrip": False,
|
305 |
+
"normalized": True,
|
306 |
+
"rstrip": False,
|
307 |
+
"single_word": False
|
308 |
+
},
|
309 |
+
"unk_token": {
|
310 |
+
"__type": "AddedToken",
|
311 |
+
"content": "<unk>",
|
312 |
+
"lstrip": False,
|
313 |
+
"normalized": True,
|
314 |
+
"rstrip": False,
|
315 |
+
"single_word": False
|
316 |
+
},
|
317 |
+
},
|
318 |
+
os.path.join(tokenizer_path, "tokenizer_config.json"),
|
319 |
+
)
|
320 |
+
shutil.copyfile(input_tokenizer_path, os.path.join(tokenizer_path, "tokenizer.model"))
|
321 |
+
|
322 |
+
|
323 |
+
def main(argv):
|
324 |
+
assert FLAGS.load_checkpoint != "" and FLAGS.output_dir != ""# and FLAGS.tokenizer_path != ""
|
325 |
+
assert FLAGS.model_size in LLAMA_STANDARD_CONFIGS
|
326 |
+
# write_tokenizer(
|
327 |
+
# tokenizer_path=FLAGS.output_dir,
|
328 |
+
# input_tokenizer_path=FLAGS.tokenizer_path,
|
329 |
+
# )
|
330 |
+
write_model(
|
331 |
+
load_and_convert_checkpoint(FLAGS.load_checkpoint),
|
332 |
+
model_path=FLAGS.output_dir,
|
333 |
+
model_size=FLAGS.model_size,
|
334 |
+
)
|
335 |
+
|
336 |
+
|
337 |
+
if __name__ == "__main__":
|
338 |
+
mlxu.run(main)
|
EasyLM/models/llama/convert_hf_to_easylm.py
ADDED
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage:
|
3 |
+
python convert_hf_to_easylm.py \
|
4 |
+
--checkpoint_dir /path/hf_format_dir/ \
|
5 |
+
--output_file /path/easylm_format.stream \
|
6 |
+
--model_size 7b \
|
7 |
+
--streaming
|
8 |
+
"""
|
9 |
+
import time
|
10 |
+
from pathlib import Path
|
11 |
+
import argparse
|
12 |
+
|
13 |
+
import mlxu
|
14 |
+
import torch
|
15 |
+
import flax
|
16 |
+
|
17 |
+
from EasyLM.checkpoint import StreamingCheckpointer
|
18 |
+
|
19 |
+
LLAMA_STANDARD_CONFIGS = {
|
20 |
+
'1b': {
|
21 |
+
'dim': 2048,
|
22 |
+
'intermediate_size': 5504,
|
23 |
+
'n_layers': 22,
|
24 |
+
'n_heads': 16,
|
25 |
+
'norm_eps': 1e-6,
|
26 |
+
},
|
27 |
+
'3b': {
|
28 |
+
'dim': 3200,
|
29 |
+
'intermediate_size': 8640,
|
30 |
+
'n_layers': 26,
|
31 |
+
'n_heads': 32,
|
32 |
+
'norm_eps': 1e-6,
|
33 |
+
},
|
34 |
+
"7b": {
|
35 |
+
"dim": 4096,
|
36 |
+
"intermediate_size": 11008,
|
37 |
+
"n_layers": 32,
|
38 |
+
"n_heads": 32,
|
39 |
+
"norm_eps": 1e-6,
|
40 |
+
},
|
41 |
+
"13b": {
|
42 |
+
"dim": 5120,
|
43 |
+
"intermediate_size": 13824,
|
44 |
+
"n_layers": 40,
|
45 |
+
"n_heads": 40,
|
46 |
+
"norm_eps": 1e-6,
|
47 |
+
},
|
48 |
+
"30b": {
|
49 |
+
"dim": 6656,
|
50 |
+
"intermediate_size": 17920,
|
51 |
+
"n_layers": 60,
|
52 |
+
"n_heads": 52,
|
53 |
+
"norm_eps": 1e-6,
|
54 |
+
},
|
55 |
+
"65b": {
|
56 |
+
"dim": 8192,
|
57 |
+
"intermediate_size": 22016,
|
58 |
+
"n_layers": 80,
|
59 |
+
"n_heads": 64,
|
60 |
+
"norm_eps": 1e-5,
|
61 |
+
},
|
62 |
+
}
|
63 |
+
|
64 |
+
|
65 |
+
def inverse_permute(params, w):
|
66 |
+
n_layers = params["n_layers"]
|
67 |
+
n_heads = params["n_heads"]
|
68 |
+
dim = params["dim"]
|
69 |
+
reshaped_w = w.reshape(n_heads, 2, dim // n_heads // 2, dim)
|
70 |
+
transposed_w = reshaped_w.transpose(0, 2, 1, 3)
|
71 |
+
inverted_w = transposed_w.reshape(dim, dim)
|
72 |
+
return inverted_w
|
73 |
+
|
74 |
+
|
75 |
+
def main(args):
|
76 |
+
start = time.time()
|
77 |
+
params = LLAMA_STANDARD_CONFIGS[args.model_size]
|
78 |
+
|
79 |
+
ckpt_paths = sorted(Path(args.checkpoint_dir).glob("*.bin"))
|
80 |
+
ckpt = {}
|
81 |
+
for i, ckpt_path in enumerate(ckpt_paths):
|
82 |
+
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
83 |
+
for k, v in checkpoint.items():
|
84 |
+
if k.startswith("model."):
|
85 |
+
k = k[6:]
|
86 |
+
ckpt[k] = v
|
87 |
+
print(f"Start convert weight to easylm format...")
|
88 |
+
jax_weights = {
|
89 |
+
"transformer": {
|
90 |
+
"wte": {"embedding": ckpt["embed_tokens.weight"].numpy()},
|
91 |
+
"ln_f": {"kernel": ckpt["norm.weight"].numpy()},
|
92 |
+
"h": {
|
93 |
+
"%d"
|
94 |
+
% (layer): {
|
95 |
+
"attention": {
|
96 |
+
"wq": {
|
97 |
+
"kernel": inverse_permute(
|
98 |
+
params,
|
99 |
+
ckpt[f"layers.{layer}.self_attn.q_proj.weight"].numpy(),
|
100 |
+
).transpose()
|
101 |
+
},
|
102 |
+
"wk": {
|
103 |
+
"kernel": inverse_permute(
|
104 |
+
params,
|
105 |
+
ckpt[f"layers.{layer}.self_attn.k_proj.weight"].numpy(),
|
106 |
+
).transpose()
|
107 |
+
},
|
108 |
+
"wv": {
|
109 |
+
"kernel": ckpt[f"layers.{layer}.self_attn.v_proj.weight"]
|
110 |
+
.numpy()
|
111 |
+
.transpose()
|
112 |
+
},
|
113 |
+
"wo": {
|
114 |
+
"kernel": ckpt[f"layers.{layer}.self_attn.o_proj.weight"]
|
115 |
+
.numpy()
|
116 |
+
.transpose()
|
117 |
+
},
|
118 |
+
},
|
119 |
+
"feed_forward": {
|
120 |
+
"w1": {
|
121 |
+
"kernel": ckpt[f"layers.{layer}.mlp.gate_proj.weight"]
|
122 |
+
.numpy()
|
123 |
+
.transpose()
|
124 |
+
},
|
125 |
+
"w2": {
|
126 |
+
"kernel": ckpt[f"layers.{layer}.mlp.down_proj.weight"]
|
127 |
+
.numpy()
|
128 |
+
.transpose()
|
129 |
+
},
|
130 |
+
"w3": {
|
131 |
+
"kernel": ckpt[f"layers.{layer}.mlp.up_proj.weight"]
|
132 |
+
.numpy()
|
133 |
+
.transpose()
|
134 |
+
},
|
135 |
+
},
|
136 |
+
"attention_norm": {
|
137 |
+
"kernel": ckpt[f"layers.{layer}.input_layernorm.weight"].numpy()
|
138 |
+
},
|
139 |
+
"ffn_norm": {
|
140 |
+
"kernel": ckpt[
|
141 |
+
f"layers.{layer}.post_attention_layernorm.weight"
|
142 |
+
].numpy()
|
143 |
+
},
|
144 |
+
}
|
145 |
+
for layer in range(params["n_layers"])
|
146 |
+
},
|
147 |
+
},
|
148 |
+
"lm_head": {"kernel": ckpt["lm_head.weight"].numpy().transpose()},
|
149 |
+
}
|
150 |
+
print(f"Convert weight to easylm format finished...")
|
151 |
+
print(f"Start to save...")
|
152 |
+
|
153 |
+
if args.streaming:
|
154 |
+
StreamingCheckpointer.save_train_state_to_file(jax_weights, args.output_file)
|
155 |
+
else:
|
156 |
+
with mlxu.open_file(args.output_file, "wb") as fout:
|
157 |
+
fout.write(flax.serialization.msgpack_serialize(jax_weights, in_place=True))
|
158 |
+
|
159 |
+
print(
|
160 |
+
f"Save finished!!! take time: {time.time() - start} save path: {args.output_file}"
|
161 |
+
)
|
162 |
+
|
163 |
+
|
164 |
+
if __name__ == "__main__":
|
165 |
+
parser = argparse.ArgumentParser(description="hf to easylm format script")
|
166 |
+
|
167 |
+
parser.add_argument(
|
168 |
+
"--checkpoint_dir",
|
169 |
+
type=str,
|
170 |
+
help="Need to be converted model weight dir. it is a dir",
|
171 |
+
)
|
172 |
+
parser.add_argument(
|
173 |
+
"--output_file", type=str, help="Save model weight file path, it is a file."
|
174 |
+
)
|
175 |
+
parser.add_argument(
|
176 |
+
"--model_size",
|
177 |
+
type=str,
|
178 |
+
default="7b",
|
179 |
+
choices=["7b", "13b", "30b", "65b"],
|
180 |
+
help="model size",
|
181 |
+
)
|
182 |
+
parser.add_argument(
|
183 |
+
"--streaming",
|
184 |
+
action="store_true",
|
185 |
+
default=True,
|
186 |
+
help="whether is model weight saved stream format",
|
187 |
+
)
|
188 |
+
|
189 |
+
args = parser.parse_args()
|
190 |
+
|
191 |
+
print(f"checkpoint_dir: {args.checkpoint_dir}")
|
192 |
+
print(f"output_file: {args.output_file}")
|
193 |
+
print(f"model_size: {args.model_size}")
|
194 |
+
print(f"streaming: {args.streaming}")
|
195 |
+
|
196 |
+
main(args)
|
EasyLM/models/llama/convert_torch_to_easylm.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This script converts the standrd LLaMA PyTorch checkpoint released by Meta
|
2 |
+
# to the EasyLM checkpoint format. The converted checkpoint can then be loaded
|
3 |
+
# by EasyLM for fine-tuning or inference.
|
4 |
+
|
5 |
+
# This script is largely borrow from https://github.com/Sea-Snell/JAX_llama
|
6 |
+
|
7 |
+
from pathlib import Path
|
8 |
+
import json
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import flax
|
12 |
+
import mlxu
|
13 |
+
|
14 |
+
from EasyLM.checkpoint import StreamingCheckpointer
|
15 |
+
|
16 |
+
|
17 |
+
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
18 |
+
checkpoint_dir='',
|
19 |
+
output_file='',
|
20 |
+
streaming=True,
|
21 |
+
)
|
22 |
+
|
23 |
+
|
24 |
+
def main(argv):
|
25 |
+
ckpt_paths = sorted(Path(FLAGS.checkpoint_dir).glob("*.pth"))
|
26 |
+
ckpts = {}
|
27 |
+
for i, ckpt_path in enumerate(ckpt_paths):
|
28 |
+
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
29 |
+
ckpts[int(ckpt_path.name.split('.', maxsplit=2)[1])] = checkpoint
|
30 |
+
ckpts = [ckpts[i] for i in sorted(list(ckpts.keys()))]
|
31 |
+
with open(Path(FLAGS.checkpoint_dir) / "params.json", "r") as f:
|
32 |
+
params = json.loads(f.read())
|
33 |
+
|
34 |
+
jax_weights = {
|
35 |
+
'transformer': {
|
36 |
+
'wte': {'embedding': np.concatenate([ckpt['tok_embeddings.weight'].numpy() for ckpt in ckpts], axis=1)},
|
37 |
+
'ln_f': {'kernel': ckpts[0]['norm.weight'].numpy()},
|
38 |
+
'h': {
|
39 |
+
'%d' % (layer): {
|
40 |
+
'attention': {
|
41 |
+
'wq': {'kernel': np.concatenate([ckpt['layers.%d.attention.wq.weight' % (layer)].numpy() for ckpt in ckpts], axis=0).transpose()},
|
42 |
+
'wk': {'kernel': np.concatenate([ckpt['layers.%d.attention.wk.weight' % (layer)].numpy() for ckpt in ckpts], axis=0).transpose()},
|
43 |
+
'wv': {'kernel': np.concatenate([ckpt['layers.%d.attention.wv.weight' % (layer)].numpy() for ckpt in ckpts], axis=0).transpose()},
|
44 |
+
'wo': {'kernel': np.concatenate([ckpt['layers.%d.attention.wo.weight' % (layer)].numpy() for ckpt in ckpts], axis=1).transpose()},
|
45 |
+
},
|
46 |
+
'feed_forward': {
|
47 |
+
'w1': {'kernel': np.concatenate([ckpt['layers.%d.feed_forward.w1.weight' % (layer)].numpy() for ckpt in ckpts], axis=0).transpose()},
|
48 |
+
'w2': {'kernel': np.concatenate([ckpt['layers.%d.feed_forward.w2.weight' % (layer)].numpy() for ckpt in ckpts], axis=1).transpose()},
|
49 |
+
'w3': {'kernel': np.concatenate([ckpt['layers.%d.feed_forward.w3.weight' % (layer)].numpy() for ckpt in ckpts], axis=0).transpose()},
|
50 |
+
},
|
51 |
+
'attention_norm': {'kernel': ckpts[0]['layers.%d.attention_norm.weight' % (layer)].numpy()},
|
52 |
+
'ffn_norm': {'kernel': ckpts[0]['layers.%d.ffn_norm.weight' % (layer)].numpy()},
|
53 |
+
}
|
54 |
+
for layer in range(params['n_layers'])},
|
55 |
+
},
|
56 |
+
'lm_head': {'kernel': np.concatenate([ckpt['output.weight'].numpy() for ckpt in ckpts], axis=0).transpose()},
|
57 |
+
}
|
58 |
+
if FLAGS.streaming:
|
59 |
+
StreamingCheckpointer.save_train_state_to_file(
|
60 |
+
jax_weights, FLAGS.output_file
|
61 |
+
)
|
62 |
+
else:
|
63 |
+
with mlxu.open_file(FLAGS.output_file, 'wb') as fout:
|
64 |
+
fout.write(flax.serialization.msgpack_serialize(jax_weights, in_place=True))
|
65 |
+
|
66 |
+
|
67 |
+
if __name__ == '__main__':
|
68 |
+
mlxu.run(main)
|
EasyLM/models/llama/llama_model.py
ADDED
@@ -0,0 +1,1360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from shutil import copyfile
|
3 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
4 |
+
import json
|
5 |
+
import tempfile
|
6 |
+
from functools import partial
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import jax
|
10 |
+
import jax.numpy as jnp
|
11 |
+
from jax import lax
|
12 |
+
from jax.sharding import PartitionSpec as PS
|
13 |
+
import flax.linen as nn
|
14 |
+
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
15 |
+
from flax.linen import combine_masks, make_causal_mask
|
16 |
+
from flax.linen.attention import dot_product_attention_weights
|
17 |
+
from flax.traverse_util import flatten_dict, unflatten_dict
|
18 |
+
from flax.linen import partitioning as nn_partitioning
|
19 |
+
import einops
|
20 |
+
|
21 |
+
import sentencepiece as spm
|
22 |
+
from transformers import AutoTokenizer
|
23 |
+
from transformers.configuration_utils import PretrainedConfig
|
24 |
+
from transformers.utils import logging
|
25 |
+
from transformers.tokenization_utils import PreTrainedTokenizer
|
26 |
+
from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
|
27 |
+
from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
|
28 |
+
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
29 |
+
|
30 |
+
from ml_collections import ConfigDict
|
31 |
+
from ml_collections.config_dict import config_dict
|
32 |
+
from mlxu import function_args_to_config, load_pickle, open_file
|
33 |
+
|
34 |
+
from EasyLM.bpt import blockwise_ffn, blockwise_attn
|
35 |
+
from EasyLM.jax_utils import (
|
36 |
+
with_sharding_constraint, get_jax_mesh, get_gradient_checkpoint_policy
|
37 |
+
)
|
38 |
+
|
39 |
+
|
40 |
+
LLAMA_STANDARD_CONFIGS = {
|
41 |
+
'small': {
|
42 |
+
'vocab_size': 64256,
|
43 |
+
'hidden_size': 768,
|
44 |
+
'intermediate_size': 3072,
|
45 |
+
'num_hidden_layers': 12,
|
46 |
+
'num_attention_heads': 12,
|
47 |
+
'max_sequence_length': 2048,
|
48 |
+
'initializer_range': 0.02,
|
49 |
+
'rms_norm_eps': 1e-6,
|
50 |
+
'use_cache': True,
|
51 |
+
'tie_word_embeddings': False,
|
52 |
+
},
|
53 |
+
'medium': {
|
54 |
+
'vocab_size': 64256,
|
55 |
+
'hidden_size': 1024,
|
56 |
+
'intermediate_size': 4096,
|
57 |
+
'num_hidden_layers': 24,
|
58 |
+
'num_attention_heads': 16,
|
59 |
+
'max_sequence_length': 2048,
|
60 |
+
'initializer_range': 0.02,
|
61 |
+
'rms_norm_eps': 1e-6,
|
62 |
+
'use_cache': True,
|
63 |
+
'tie_word_embeddings': False,
|
64 |
+
},
|
65 |
+
'large': {
|
66 |
+
'vocab_size': 64256,
|
67 |
+
'hidden_size': 1536,
|
68 |
+
'intermediate_size': 6144,
|
69 |
+
'num_hidden_layers': 24,
|
70 |
+
'num_attention_heads': 16,
|
71 |
+
'max_sequence_length': 2048,
|
72 |
+
'initializer_range': 0.02,
|
73 |
+
'rms_norm_eps': 1e-6,
|
74 |
+
'use_cache': True,
|
75 |
+
'tie_word_embeddings': False,
|
76 |
+
},
|
77 |
+
'xlarge': {
|
78 |
+
'vocab_size': 64256,
|
79 |
+
'hidden_size': 2048,
|
80 |
+
'intermediate_size': 8192,
|
81 |
+
'num_hidden_layers': 24,
|
82 |
+
'num_attention_heads': 32,
|
83 |
+
'max_sequence_length': 2048,
|
84 |
+
'initializer_range': 0.02,
|
85 |
+
'rms_norm_eps': 1e-6,
|
86 |
+
'use_cache': True,
|
87 |
+
'tie_word_embeddings': False,
|
88 |
+
},
|
89 |
+
'1b': {
|
90 |
+
'vocab_size': 64256,
|
91 |
+
'hidden_size': 2048,
|
92 |
+
'intermediate_size': 5504,
|
93 |
+
'num_hidden_layers': 22,
|
94 |
+
'num_attention_heads': 16,
|
95 |
+
'max_sequence_length': 2048,
|
96 |
+
'initializer_range': 0.02,
|
97 |
+
'rms_norm_eps': 1e-6,
|
98 |
+
'use_cache': True,
|
99 |
+
'tie_word_embeddings': False,
|
100 |
+
},
|
101 |
+
'3b': {
|
102 |
+
'vocab_size': 64256,
|
103 |
+
'hidden_size': 3200,
|
104 |
+
'intermediate_size': 8640,
|
105 |
+
'num_hidden_layers': 26,
|
106 |
+
'num_attention_heads': 32,
|
107 |
+
'max_sequence_length': 2048,
|
108 |
+
'initializer_range': 0.02,
|
109 |
+
'rms_norm_eps': 1e-6,
|
110 |
+
'use_cache': True,
|
111 |
+
'tie_word_embeddings': False,
|
112 |
+
},
|
113 |
+
'7b': {
|
114 |
+
'vocab_size': 64256,
|
115 |
+
'hidden_size': 4096,
|
116 |
+
'intermediate_size': 11008,
|
117 |
+
'num_hidden_layers': 32,
|
118 |
+
'num_attention_heads': 32,
|
119 |
+
'max_sequence_length': 2048,
|
120 |
+
'initializer_range': 0.02,
|
121 |
+
'rms_norm_eps': 1e-6,
|
122 |
+
'use_cache': True,
|
123 |
+
'tie_word_embeddings': False,
|
124 |
+
},
|
125 |
+
'13b': {
|
126 |
+
'vocab_size': 64256,
|
127 |
+
'hidden_size': 5120,
|
128 |
+
'intermediate_size': 13824,
|
129 |
+
'num_hidden_layers': 40,
|
130 |
+
'num_attention_heads': 40,
|
131 |
+
'max_sequence_length': 2048,
|
132 |
+
'initializer_range': 0.02,
|
133 |
+
'rms_norm_eps': 1e-6,
|
134 |
+
'use_cache': True,
|
135 |
+
'tie_word_embeddings': False,
|
136 |
+
},
|
137 |
+
'30b': {
|
138 |
+
'vocab_size': 64256,
|
139 |
+
'hidden_size': 6656,
|
140 |
+
'intermediate_size': 17920,
|
141 |
+
'num_hidden_layers': 60,
|
142 |
+
'num_attention_heads': 52,
|
143 |
+
'max_sequence_length': 2048,
|
144 |
+
'initializer_range': 0.02,
|
145 |
+
'rms_norm_eps': 1e-6,
|
146 |
+
'use_cache': True,
|
147 |
+
'tie_word_embeddings': False,
|
148 |
+
},
|
149 |
+
'65b': {
|
150 |
+
'vocab_size': 64256,
|
151 |
+
'hidden_size': 8192,
|
152 |
+
'intermediate_size': 22016,
|
153 |
+
'num_hidden_layers': 80,
|
154 |
+
'num_attention_heads': 64,
|
155 |
+
'max_sequence_length': 2048,
|
156 |
+
'initializer_range': 0.02,
|
157 |
+
'rms_norm_eps': 1e-5,
|
158 |
+
'use_cache': True,
|
159 |
+
'tie_word_embeddings': False,
|
160 |
+
},
|
161 |
+
'debug': { # A small model for debugging
|
162 |
+
'vocab_size': 64256,
|
163 |
+
'hidden_size': 128,
|
164 |
+
'intermediate_size': 256,
|
165 |
+
'num_hidden_layers': 2,
|
166 |
+
'num_attention_heads': 4,
|
167 |
+
'max_sequence_length': 2048,
|
168 |
+
'initializer_range': 0.02,
|
169 |
+
'rms_norm_eps': 1e-6,
|
170 |
+
'use_cache': True,
|
171 |
+
'tie_word_embeddings': False,
|
172 |
+
},
|
173 |
+
}
|
174 |
+
|
175 |
+
|
176 |
+
class LLaMAConfig(PretrainedConfig):
|
177 |
+
r"""
|
178 |
+
This is the configuration class to store the configuration of a [`~LLaMAModel`]. It is used to instantiate an LLaMA
|
179 |
+
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
180 |
+
defaults will yield a similar configuration to that of the LLaMA-7B.
|
181 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
182 |
+
documentation from [`PretrainedConfig`] for more information.
|
183 |
+
Args:
|
184 |
+
vocab_size (`int`, *optional*, defaults to 32000):
|
185 |
+
Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
|
186 |
+
`inputs_ids` passed when calling [`~LLaMAModel`] or [`~TFLLaMAModel`].
|
187 |
+
hidden_size (`int`, *optional*, defaults to 4096):
|
188 |
+
Dimension of the hidden representations.
|
189 |
+
intermediate_size (`int`, *optional*, defaults to 11008):
|
190 |
+
Dimension of the MLP representations.
|
191 |
+
num_hidden_layers (`int`, *optional*, defaults to 32):
|
192 |
+
Number of hidden layers in the Transformer encoder.
|
193 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
194 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
195 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
196 |
+
The non-linear activation function (function or string) in the decoder.
|
197 |
+
max_sequence_length (`int`, *optional*, defaults to 2048):
|
198 |
+
Max sequence length for model (for RoPE computation)
|
199 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
200 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
201 |
+
rms_norm_eps (`float`, *optional*, defaults to 1e-12):
|
202 |
+
The epsilon used by the rms normalization layers.
|
203 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
204 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
205 |
+
relevant if `config.is_decoder=True`.
|
206 |
+
tie_word_embeddings(`bool`, *optional*, defaults to `False`):
|
207 |
+
Whether to tie weight embeddings
|
208 |
+
Example:
|
209 |
+
```python
|
210 |
+
>>> from transformers import LLaMAModel, LLaMAConfig
|
211 |
+
>>> # Initializing a LLaMA llama-7b style configuration
|
212 |
+
>>> configuration = LLaMAConfig()
|
213 |
+
>>> # Initializing a model from the llama-7b style configuration
|
214 |
+
>>> model = LLaMAModel(configuration)
|
215 |
+
>>> # Accessing the model configuration
|
216 |
+
>>> configuration = model.config
|
217 |
+
```"""
|
218 |
+
model_type = "llama"
|
219 |
+
|
220 |
+
def __init__(
|
221 |
+
self,
|
222 |
+
vocab_size=32000,
|
223 |
+
hidden_size=4096,
|
224 |
+
intermediate_size=11008,
|
225 |
+
num_hidden_layers=32,
|
226 |
+
num_attention_heads=32,
|
227 |
+
max_sequence_length=2048,
|
228 |
+
rms_norm_eps=1e-6,
|
229 |
+
initializer_range=0.02,
|
230 |
+
use_cache=True,
|
231 |
+
# pad_token_id=-1,
|
232 |
+
bos_token_id=0,
|
233 |
+
eos_token_id=1,
|
234 |
+
resid_pdrop=0.0,
|
235 |
+
embd_pdrop=0.0,
|
236 |
+
attn_pdrop=0.0,
|
237 |
+
tie_word_embeddings=False,
|
238 |
+
remat_block='nothing_saveable',
|
239 |
+
remat_attention='',
|
240 |
+
remat_mlp='',
|
241 |
+
scan_attention=False,
|
242 |
+
scan_mlp=False,
|
243 |
+
scan_query_chunk_size=1024,
|
244 |
+
scan_key_chunk_size=1024,
|
245 |
+
scan_mlp_chunk_size=1024,
|
246 |
+
fcm_min_ratio=0.0,
|
247 |
+
fcm_max_ratio=0.0,
|
248 |
+
**kwargs,
|
249 |
+
):
|
250 |
+
self.vocab_size = vocab_size
|
251 |
+
self.hidden_size = hidden_size
|
252 |
+
self.initializer_range = initializer_range
|
253 |
+
self.intermediate_size = intermediate_size
|
254 |
+
self.num_hidden_layers = num_hidden_layers
|
255 |
+
self.num_attention_heads = num_attention_heads
|
256 |
+
self.max_sequence_length = max_sequence_length
|
257 |
+
self.rms_norm_eps = rms_norm_eps
|
258 |
+
self.use_cache = use_cache
|
259 |
+
self.resid_pdrop = resid_pdrop
|
260 |
+
self.embd_pdrop = embd_pdrop
|
261 |
+
self.attn_pdrop = attn_pdrop
|
262 |
+
self.remat_block = remat_block
|
263 |
+
self.remat_attention = remat_attention
|
264 |
+
self.remat_mlp = remat_mlp
|
265 |
+
self.scan_attention = scan_attention
|
266 |
+
self.scan_mlp = scan_mlp
|
267 |
+
self.scan_query_chunk_size = scan_query_chunk_size
|
268 |
+
self.scan_key_chunk_size = scan_key_chunk_size
|
269 |
+
self.scan_mlp_chunk_size = scan_mlp_chunk_size
|
270 |
+
self.fcm_min_ratio = fcm_min_ratio
|
271 |
+
self.fcm_max_ratio = fcm_max_ratio
|
272 |
+
super().__init__(
|
273 |
+
# pad_token_id=pad_token_id,
|
274 |
+
bos_token_id=bos_token_id,
|
275 |
+
eos_token_id=eos_token_id,
|
276 |
+
tie_word_embeddings=tie_word_embeddings,
|
277 |
+
**kwargs,
|
278 |
+
)
|
279 |
+
|
280 |
+
@classmethod
|
281 |
+
def get_default_config(cls, updates=None):
|
282 |
+
config = function_args_to_config(cls.__init__)
|
283 |
+
|
284 |
+
if updates is not None:
|
285 |
+
config.update(ConfigDict(updates).copy_and_resolve_references())
|
286 |
+
|
287 |
+
return config
|
288 |
+
|
289 |
+
@staticmethod
|
290 |
+
def get_jax_mesh(axis_dims):
|
291 |
+
return get_jax_mesh(axis_dims, ('dp', 'fsdp', 'mp'))
|
292 |
+
|
293 |
+
@staticmethod
|
294 |
+
def get_partition_rules():
|
295 |
+
""" Parition rules for GPTJ. Note that these rules are orderd, so that
|
296 |
+
the beginning rules match first. It is important to use
|
297 |
+
PartitionSpec() instead of None here because JAX does not treat
|
298 |
+
None as a pytree leaf.
|
299 |
+
"""
|
300 |
+
return (
|
301 |
+
# embeddings
|
302 |
+
("transformer/wte/embedding", PS("mp", "fsdp")),
|
303 |
+
# atention
|
304 |
+
("attention/(wq|wk|wv)/kernel", PS("fsdp", "mp")),
|
305 |
+
("attention/wo/kernel", PS("mp", "fsdp")),
|
306 |
+
# mlp
|
307 |
+
("feed_forward/w1/kernel", PS("fsdp", "mp")),
|
308 |
+
("feed_forward/w2/kernel", PS("mp", "fsdp")),
|
309 |
+
("feed_forward/w3/kernel", PS("fsdp", "mp")),
|
310 |
+
# layer norms
|
311 |
+
("attention_norm/kernel", PS(None)),
|
312 |
+
("ffn_norm/kernel", PS(None)),
|
313 |
+
# output head
|
314 |
+
("transformer/ln_f/kernel", PS(None)),
|
315 |
+
("lm_head/kernel", PS("fsdp", "mp")),
|
316 |
+
('.*', PS(None)),
|
317 |
+
)
|
318 |
+
|
319 |
+
@staticmethod
|
320 |
+
def get_weight_decay_exclusions():
|
321 |
+
return (
|
322 |
+
"attention_norm/kernel",
|
323 |
+
"ffn_norm/kernel",
|
324 |
+
"transformer/ln_f/kernel",
|
325 |
+
)
|
326 |
+
|
327 |
+
@staticmethod
|
328 |
+
def rng_keys():
|
329 |
+
return ('params', 'dropout', 'fcm')
|
330 |
+
|
331 |
+
@staticmethod
|
332 |
+
def get_tokenizer_config(updates=None):
|
333 |
+
config = ConfigDict()
|
334 |
+
config.vocab_file = ''
|
335 |
+
config.pretrained_model_name_or_path = ''
|
336 |
+
config.add_bos_token = False
|
337 |
+
config.add_eos_token = False
|
338 |
+
|
339 |
+
if updates is not None:
|
340 |
+
config.update(ConfigDict(updates).copy_and_resolve_references())
|
341 |
+
return config
|
342 |
+
|
343 |
+
@classmethod
|
344 |
+
def get_tokenizer(cls, config, padding_side='left', truncation_side='right'):
|
345 |
+
config = cls.get_tokenizer_config(config)
|
346 |
+
if config.vocab_file == '':
|
347 |
+
assert config.pretrained_model_name_or_path != '', 'vocab_file or pretrained_model_name_or_path must be specified'
|
348 |
+
|
349 |
+
if config.pretrained_model_name_or_path != '':
|
350 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
351 |
+
config.pretrained_model_name_or_path,
|
352 |
+
add_bos_token=config.add_bos_token,
|
353 |
+
add_eos_token=config.add_eos_token,
|
354 |
+
padding_side=padding_side,
|
355 |
+
truncation_side=truncation_side,
|
356 |
+
)
|
357 |
+
else:
|
358 |
+
tokenizer = LLaMATokenizer(
|
359 |
+
vocab_file=config.vocab_file,
|
360 |
+
add_bos_token=config.add_bos_token,
|
361 |
+
add_eos_token=config.add_eos_token,
|
362 |
+
padding_side=padding_side,
|
363 |
+
truncation_side=truncation_side,
|
364 |
+
)
|
365 |
+
return tokenizer
|
366 |
+
|
367 |
+
@classmethod
|
368 |
+
def load_config(cls, path):
|
369 |
+
if path in LLAMA_STANDARD_CONFIGS:
|
370 |
+
return cls.from_dict(LLAMA_STANDARD_CONFIGS[path])
|
371 |
+
load_type, load_path = path.split('::', 1)
|
372 |
+
if load_type == 'pickle':
|
373 |
+
return cls.from_dict(load_pickle(load_path)['llama_config'])
|
374 |
+
elif load_type == 'json':
|
375 |
+
with open_file(load_path, 'r') as fin:
|
376 |
+
raw_config = fin.read()
|
377 |
+
return cls.from_dict(json.loads(raw_config))
|
378 |
+
else:
|
379 |
+
raise ValueError(f'Unsupported load config type: {load_type}')
|
380 |
+
|
381 |
+
|
382 |
+
remat = nn_partitioning.remat
|
383 |
+
|
384 |
+
logger = logging.get_logger(__name__)
|
385 |
+
|
386 |
+
|
387 |
+
class RMSNorm(nn.Module):
|
388 |
+
dim: int
|
389 |
+
eps: float=1e-6
|
390 |
+
dtype: jnp.dtype=jnp.float32
|
391 |
+
param_dtype: jnp.dtype=jnp.float32
|
392 |
+
|
393 |
+
def setup(self) -> None:
|
394 |
+
self.weight = self.param(
|
395 |
+
'kernel',
|
396 |
+
nn.initializers.ones,
|
397 |
+
(self.dim,),
|
398 |
+
self.param_dtype,
|
399 |
+
)
|
400 |
+
|
401 |
+
def _norm(self, x: jnp.ndarray) -> jnp.ndarray:
|
402 |
+
return x * jax.lax.rsqrt(jnp.square(x).mean(-1, keepdims=True) + self.eps)
|
403 |
+
|
404 |
+
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
|
405 |
+
x = x.astype(jnp.promote_types(self.dtype, jnp.float32))
|
406 |
+
output = self._norm(x).astype(self.dtype)
|
407 |
+
weight = jnp.asarray(self.weight, self.dtype)
|
408 |
+
return output * weight
|
409 |
+
|
410 |
+
def precompute_freqs_cis(dim: int, end: int, theta: float=10000.0, dtype: jnp.dtype=jnp.float32) -> jnp.ndarray:
|
411 |
+
freqs = 1.0 / (theta ** (np.arange(0, dim, 2)[: (dim // 2)].astype(dtype) / dim))
|
412 |
+
t = np.arange(end) # type: ignore
|
413 |
+
freqs = np.outer(t, freqs).astype(dtype) # type: ignore
|
414 |
+
sin, cos = np.sin(freqs), np.cos(freqs)
|
415 |
+
freqs_cis = np.complex64(cos + 1j * sin)
|
416 |
+
return jnp.asarray(freqs_cis)
|
417 |
+
|
418 |
+
def apply_rotary_emb(
|
419 |
+
xq: jnp.ndarray,
|
420 |
+
xk: jnp.ndarray,
|
421 |
+
freqs_cis: jnp.ndarray,
|
422 |
+
dtype: jnp.dtype=jnp.float32,
|
423 |
+
) -> Tuple[jnp.ndarray, jnp.ndarray]:
|
424 |
+
|
425 |
+
reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2)
|
426 |
+
reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2)
|
427 |
+
|
428 |
+
xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1])
|
429 |
+
xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1])
|
430 |
+
|
431 |
+
# add head dim
|
432 |
+
freqs_cis = jnp.reshape(freqs_cis, (*freqs_cis.shape[:2], 1, *freqs_cis.shape[2:]))
|
433 |
+
|
434 |
+
xq_out = xq_ * freqs_cis
|
435 |
+
xq_out = jnp.stack((jnp.real(xq_out), jnp.imag(xq_out)), axis=-1).reshape(*xq_out.shape[:-1], -1)
|
436 |
+
|
437 |
+
xk_out = xk_ * freqs_cis
|
438 |
+
xk_out = jnp.stack((jnp.real(xk_out), jnp.imag(xk_out)), axis=-1).reshape(*xk_out.shape[:-1], -1)
|
439 |
+
|
440 |
+
return xq_out.astype(dtype), xk_out.astype(dtype)
|
441 |
+
|
442 |
+
|
443 |
+
class FlaxLLaMAAttention(nn.Module):
|
444 |
+
config: LLaMAConfig
|
445 |
+
dtype: jnp.dtype=jnp.float32
|
446 |
+
param_dtype: jnp.dtype=jnp.float32
|
447 |
+
precision: Optional[Union[jax.lax.Precision, str]]=None
|
448 |
+
|
449 |
+
def setup(self):
|
450 |
+
config = self.config
|
451 |
+
self.embed_dim = config.hidden_size
|
452 |
+
self.num_heads = config.num_attention_heads
|
453 |
+
self.head_dim = self.embed_dim // self.num_heads
|
454 |
+
|
455 |
+
self.wq = nn.Dense(
|
456 |
+
config.num_attention_heads*self.head_dim,
|
457 |
+
dtype=self.dtype,
|
458 |
+
param_dtype=self.param_dtype,
|
459 |
+
use_bias=False,
|
460 |
+
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
461 |
+
precision=self.precision,
|
462 |
+
)
|
463 |
+
self.wk = nn.Dense(
|
464 |
+
config.num_attention_heads*self.head_dim,
|
465 |
+
dtype=self.dtype,
|
466 |
+
param_dtype=self.param_dtype,
|
467 |
+
use_bias=False,
|
468 |
+
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
469 |
+
precision=self.precision,
|
470 |
+
)
|
471 |
+
self.wv = nn.Dense(
|
472 |
+
config.num_attention_heads*self.head_dim,
|
473 |
+
dtype=self.dtype,
|
474 |
+
param_dtype=self.param_dtype,
|
475 |
+
use_bias=False,
|
476 |
+
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
477 |
+
precision=self.precision,
|
478 |
+
)
|
479 |
+
self.wo = nn.Dense(
|
480 |
+
config.hidden_size,
|
481 |
+
dtype=self.dtype,
|
482 |
+
param_dtype=self.param_dtype,
|
483 |
+
use_bias=False,
|
484 |
+
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
485 |
+
precision=self.precision,
|
486 |
+
)
|
487 |
+
|
488 |
+
self.resid_dropout = nn.Dropout(rate=config.resid_pdrop)
|
489 |
+
|
490 |
+
self.causal_mask = make_causal_mask(jnp.ones((1, config.max_sequence_length), dtype="bool"), dtype="bool")
|
491 |
+
|
492 |
+
self.freqs_cis = precompute_freqs_cis(
|
493 |
+
self.head_dim,
|
494 |
+
config.max_sequence_length * 2,
|
495 |
+
dtype=self.dtype,
|
496 |
+
)
|
497 |
+
|
498 |
+
def _split_heads(self, hidden_states):
|
499 |
+
return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
|
500 |
+
|
501 |
+
def _merge_heads(self, hidden_states):
|
502 |
+
return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
|
503 |
+
|
504 |
+
@nn.compact
|
505 |
+
def _concatenate_to_cache(self, key, value, query, attention_mask):
|
506 |
+
"""
|
507 |
+
This function takes projected key, value states from a single input token and concatenates the states to cached
|
508 |
+
states from previous steps. This function is slighly adapted from the official Flax repository:
|
509 |
+
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
|
510 |
+
"""
|
511 |
+
# detect if we're initializing by absence of existing cache data.
|
512 |
+
is_initialized = self.has_variable("cache", "cached_key")
|
513 |
+
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
|
514 |
+
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
|
515 |
+
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
|
516 |
+
|
517 |
+
if is_initialized:
|
518 |
+
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
|
519 |
+
# update key, value caches with our new 1d spatial slices
|
520 |
+
cur_index = cache_index.value
|
521 |
+
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
|
522 |
+
key = lax.dynamic_update_slice(cached_key.value, key, indices)
|
523 |
+
value = lax.dynamic_update_slice(cached_value.value, value, indices)
|
524 |
+
cached_key.value = key
|
525 |
+
cached_value.value = value
|
526 |
+
num_updated_cache_vectors = query.shape[1]
|
527 |
+
cache_index.value = cache_index.value + num_updated_cache_vectors
|
528 |
+
# causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
|
529 |
+
pad_mask = jnp.broadcast_to(
|
530 |
+
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
|
531 |
+
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
|
532 |
+
)
|
533 |
+
attention_mask = combine_masks(pad_mask, attention_mask)
|
534 |
+
return key, value, attention_mask
|
535 |
+
|
536 |
+
def __call__(
|
537 |
+
self,
|
538 |
+
hidden_states,
|
539 |
+
attention_mask,
|
540 |
+
position_ids,
|
541 |
+
deterministic: bool = True,
|
542 |
+
init_cache: bool = False,
|
543 |
+
output_attentions: bool = False,
|
544 |
+
fcm_mask=None,
|
545 |
+
):
|
546 |
+
xq, xk, xv = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)
|
547 |
+
|
548 |
+
xq = with_sharding_constraint(xq, PS(("dp", "fsdp"), None, "mp"))
|
549 |
+
xk = with_sharding_constraint(xk, PS(("dp", "fsdp"), None, "mp"))
|
550 |
+
xv = with_sharding_constraint(xv, PS(("dp", "fsdp"), None, "mp"))
|
551 |
+
|
552 |
+
xq = self._split_heads(xq)
|
553 |
+
xk = self._split_heads(xk)
|
554 |
+
xv = self._split_heads(xv)
|
555 |
+
|
556 |
+
freqs_cis = jnp.take(self.freqs_cis, position_ids, axis=0)
|
557 |
+
|
558 |
+
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis, dtype=self.dtype)
|
559 |
+
|
560 |
+
dropout_rng = None
|
561 |
+
if not deterministic and self.config.attn_pdrop > 0.0:
|
562 |
+
dropout_rng = self.make_rng("dropout")
|
563 |
+
|
564 |
+
if self.config.scan_attention and not (self.has_variable("cache", "cached_key") or init_cache):
|
565 |
+
# doesn't need blockwise attention if we are doing autoregressive decoding since no quadratic memory
|
566 |
+
|
567 |
+
# attention mask without nxn materlization, blockwise_attn will handle the rest
|
568 |
+
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
569 |
+
# transform boolean mask into float mask
|
570 |
+
attention_bias = lax.select(
|
571 |
+
attention_mask > 0,
|
572 |
+
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
|
573 |
+
jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
|
574 |
+
)
|
575 |
+
attn_weights = None
|
576 |
+
attn_output = blockwise_attn(
|
577 |
+
xq,
|
578 |
+
xk,
|
579 |
+
xv,
|
580 |
+
bias=attention_bias,
|
581 |
+
deterministic=deterministic,
|
582 |
+
dropout_rng=dropout_rng,
|
583 |
+
attn_pdrop=self.config.attn_pdrop,
|
584 |
+
causal=True,
|
585 |
+
query_chunk_size=self.config.scan_query_chunk_size,
|
586 |
+
key_chunk_size=self.config.scan_key_chunk_size,
|
587 |
+
dtype=self.dtype,
|
588 |
+
policy=get_gradient_checkpoint_policy('nothing_saveable'),
|
589 |
+
precision=self.precision,
|
590 |
+
float32_logits=True,
|
591 |
+
prevent_cse=True,
|
592 |
+
)
|
593 |
+
attn_output = with_sharding_constraint(attn_output, PS(("dp", "fsdp"), None, "mp", None))
|
594 |
+
else:
|
595 |
+
query_length, key_length = xq.shape[1], xk.shape[1]
|
596 |
+
|
597 |
+
if self.has_variable("cache", "cached_key"):
|
598 |
+
mask_shift = self.variables["cache"]["cache_index"]
|
599 |
+
max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
|
600 |
+
causal_mask = lax.dynamic_slice(
|
601 |
+
self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
|
602 |
+
)
|
603 |
+
else:
|
604 |
+
causal_mask = self.causal_mask[:, :, :query_length, :key_length]
|
605 |
+
|
606 |
+
batch_size = hidden_states.shape[0]
|
607 |
+
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
|
608 |
+
|
609 |
+
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
|
610 |
+
attention_mask = combine_masks(attention_mask, causal_mask, fcm_mask)
|
611 |
+
|
612 |
+
# During fast autoregressive decoding, we feed one position at a time,
|
613 |
+
# and cache the keys and values step by step.
|
614 |
+
if self.has_variable("cache", "cached_key") or init_cache:
|
615 |
+
xk, xv, attention_mask = self._concatenate_to_cache(xk, xv, xq, attention_mask)
|
616 |
+
|
617 |
+
# transform boolean mask into float mask
|
618 |
+
attention_bias = lax.select(
|
619 |
+
attention_mask > 0,
|
620 |
+
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
|
621 |
+
jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
|
622 |
+
)
|
623 |
+
attn_weights = dot_product_attention_weights(
|
624 |
+
xq,
|
625 |
+
xk,
|
626 |
+
bias=attention_bias,
|
627 |
+
dropout_rng=dropout_rng,
|
628 |
+
dropout_rate=self.config.attn_pdrop,
|
629 |
+
deterministic=deterministic,
|
630 |
+
dtype=jnp.promote_types(self.dtype, jnp.float32),
|
631 |
+
precision=self.precision,
|
632 |
+
)
|
633 |
+
attn_weights = with_sharding_constraint(attn_weights, PS(("dp", "fsdp"), "mp", None, None))
|
634 |
+
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, xv, precision=self.precision)
|
635 |
+
|
636 |
+
attn_output = self._merge_heads(attn_output)
|
637 |
+
attn_output = self.wo(attn_output)
|
638 |
+
attn_output = self.resid_dropout(attn_output, deterministic=deterministic)
|
639 |
+
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
|
640 |
+
return outputs
|
641 |
+
|
642 |
+
|
643 |
+
class FlaxLLaMAMLP(nn.Module):
|
644 |
+
config: LLaMAConfig
|
645 |
+
dtype: jnp.dtype=jnp.float32
|
646 |
+
param_dtype: jnp.dtype=jnp.float32
|
647 |
+
precision: Optional[Union[jax.lax.Precision, str]]=None
|
648 |
+
|
649 |
+
def setup(self) -> None:
|
650 |
+
config = self.config
|
651 |
+
|
652 |
+
self.w1 = nn.Dense(
|
653 |
+
config.intermediate_size,
|
654 |
+
dtype=self.dtype,
|
655 |
+
param_dtype=self.param_dtype,
|
656 |
+
use_bias=False,
|
657 |
+
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
658 |
+
precision=self.precision,
|
659 |
+
)
|
660 |
+
self.w2 = nn.Dense(
|
661 |
+
config.hidden_size,
|
662 |
+
dtype=self.dtype,
|
663 |
+
param_dtype=self.param_dtype,
|
664 |
+
use_bias=False,
|
665 |
+
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
666 |
+
precision=self.precision,
|
667 |
+
)
|
668 |
+
self.w3 = nn.Dense(
|
669 |
+
config.intermediate_size,
|
670 |
+
dtype=self.dtype,
|
671 |
+
param_dtype=self.param_dtype,
|
672 |
+
use_bias=False,
|
673 |
+
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
674 |
+
precision=self.precision,
|
675 |
+
)
|
676 |
+
self.dropout = nn.Dropout(rate=self.config.resid_pdrop)
|
677 |
+
|
678 |
+
def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
|
679 |
+
x = self.w2(nn.silu(self.w1(x)) * self.w3(x))
|
680 |
+
x = self.dropout(x, deterministic=deterministic)
|
681 |
+
return x
|
682 |
+
|
683 |
+
|
684 |
+
class FlaxLLaMABlock(nn.Module):
|
685 |
+
config: LLaMAConfig
|
686 |
+
dtype: jnp.dtype=jnp.float32
|
687 |
+
param_dtype: jnp.dtype=jnp.float32
|
688 |
+
precision: Optional[Union[jax.lax.Precision, str]]=None
|
689 |
+
|
690 |
+
def setup(self) -> None:
|
691 |
+
attention_module = FlaxLLaMAAttention
|
692 |
+
mlp_module = FlaxLLaMAMLP
|
693 |
+
if self.config.remat_attention != '':
|
694 |
+
attention_module = remat(
|
695 |
+
FlaxLLaMAAttention, static_argnums=(3, 4, 5),
|
696 |
+
policy=get_gradient_checkpoint_policy(self.config.remat_attention),
|
697 |
+
prevent_cse=True,
|
698 |
+
)
|
699 |
+
if self.config.remat_mlp != '':
|
700 |
+
mlp_module = remat(
|
701 |
+
FlaxLLaMAMLP, static_argnums=(1,),
|
702 |
+
policy=get_gradient_checkpoint_policy(self.config.remat_mlp),
|
703 |
+
prevent_cse=True,
|
704 |
+
)
|
705 |
+
|
706 |
+
self.attention = attention_module(
|
707 |
+
self.config,
|
708 |
+
dtype=self.dtype,
|
709 |
+
param_dtype=self.param_dtype,
|
710 |
+
precision=self.precision,
|
711 |
+
)
|
712 |
+
self.feed_forward = mlp_module(
|
713 |
+
self.config,
|
714 |
+
dtype=self.dtype,
|
715 |
+
param_dtype=self.param_dtype,
|
716 |
+
precision=self.precision,
|
717 |
+
)
|
718 |
+
self.attention_norm = RMSNorm(
|
719 |
+
self.config.hidden_size,
|
720 |
+
eps=self.config.rms_norm_eps,
|
721 |
+
dtype=self.dtype,
|
722 |
+
param_dtype=self.param_dtype,
|
723 |
+
)
|
724 |
+
self.ffn_norm = RMSNorm(
|
725 |
+
self.config.hidden_size,
|
726 |
+
eps=self.config.rms_norm_eps,
|
727 |
+
dtype=self.dtype,
|
728 |
+
param_dtype=self.param_dtype,
|
729 |
+
)
|
730 |
+
|
731 |
+
def __call__(
|
732 |
+
self,
|
733 |
+
hidden_states,
|
734 |
+
attention_mask=None,
|
735 |
+
position_ids=None,
|
736 |
+
deterministic: bool = True,
|
737 |
+
init_cache: bool = False,
|
738 |
+
output_attentions: bool = False,
|
739 |
+
fcm_mask: Optional[jnp.ndarray] = None,
|
740 |
+
):
|
741 |
+
attn_outputs = self.attention(
|
742 |
+
self.attention_norm(hidden_states),
|
743 |
+
attention_mask,
|
744 |
+
position_ids,
|
745 |
+
deterministic,
|
746 |
+
init_cache,
|
747 |
+
output_attentions,
|
748 |
+
fcm_mask,
|
749 |
+
)
|
750 |
+
attn_output = attn_outputs[0]
|
751 |
+
hidden_states = hidden_states + attn_output
|
752 |
+
|
753 |
+
feed_forward_input = self.ffn_norm(hidden_states)
|
754 |
+
|
755 |
+
if self.config.scan_mlp:
|
756 |
+
feed_forward_hidden_states = blockwise_ffn(
|
757 |
+
self.feed_forward,
|
758 |
+
feed_forward_input,
|
759 |
+
self.config.scan_mlp_chunk_size,
|
760 |
+
deterministic,
|
761 |
+
)
|
762 |
+
else:
|
763 |
+
feed_forward_hidden_states = self.feed_forward(
|
764 |
+
feed_forward_input,
|
765 |
+
deterministic,
|
766 |
+
)
|
767 |
+
feed_forward_hidden_states = with_sharding_constraint(feed_forward_hidden_states, PS(("dp", "fsdp"), None, "mp"))
|
768 |
+
|
769 |
+
hidden_states = hidden_states + feed_forward_hidden_states
|
770 |
+
|
771 |
+
return (hidden_states,) + attn_outputs[1:]
|
772 |
+
|
773 |
+
|
774 |
+
class FlaxLLaMAPreTrainedModel(FlaxPreTrainedModel):
|
775 |
+
"""
|
776 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
777 |
+
models.
|
778 |
+
"""
|
779 |
+
|
780 |
+
config_class = LLaMAConfig
|
781 |
+
base_model_prefix = "transformer"
|
782 |
+
module_class: nn.Module = None
|
783 |
+
|
784 |
+
def __init__(
|
785 |
+
self,
|
786 |
+
config: LLaMAConfig,
|
787 |
+
input_shape: Tuple = (1, 1),
|
788 |
+
seed: int = 0,
|
789 |
+
dtype: jnp.dtype = jnp.float32,
|
790 |
+
_do_init: bool = True,
|
791 |
+
**kwargs,
|
792 |
+
):
|
793 |
+
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
794 |
+
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
795 |
+
|
796 |
+
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
797 |
+
# init input tensors
|
798 |
+
input_ids = jnp.zeros(input_shape, dtype="i4")
|
799 |
+
attention_mask = jnp.ones_like(input_ids)
|
800 |
+
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
|
801 |
+
params_rng, dropout_rng = jax.random.split(rng)
|
802 |
+
rngs = {"params": params_rng, "dropout": dropout_rng}
|
803 |
+
|
804 |
+
if self.config.add_cross_attention:
|
805 |
+
encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
|
806 |
+
encoder_attention_mask = attention_mask
|
807 |
+
module_init_outputs = self.module.init(
|
808 |
+
rngs,
|
809 |
+
input_ids,
|
810 |
+
attention_mask,
|
811 |
+
position_ids,
|
812 |
+
encoder_hidden_states,
|
813 |
+
encoder_attention_mask,
|
814 |
+
return_dict=False,
|
815 |
+
)
|
816 |
+
else:
|
817 |
+
module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)
|
818 |
+
|
819 |
+
random_params = module_init_outputs["params"]
|
820 |
+
|
821 |
+
if params is not None:
|
822 |
+
random_params = flatten_dict(unfreeze(random_params))
|
823 |
+
params = flatten_dict(unfreeze(params))
|
824 |
+
for missing_key in self._missing_keys:
|
825 |
+
params[missing_key] = random_params[missing_key]
|
826 |
+
self._missing_keys = set()
|
827 |
+
return freeze(unflatten_dict(params))
|
828 |
+
else:
|
829 |
+
return random_params
|
830 |
+
|
831 |
+
def init_cache(self, batch_size, max_length):
|
832 |
+
r"""
|
833 |
+
Args:
|
834 |
+
batch_size (`int`):
|
835 |
+
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
|
836 |
+
max_length (`int`):
|
837 |
+
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
|
838 |
+
cache.
|
839 |
+
"""
|
840 |
+
# init input variables to retrieve cache
|
841 |
+
input_ids = jnp.ones((batch_size, max_length))
|
842 |
+
attention_mask = jnp.ones_like(input_ids)
|
843 |
+
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
|
844 |
+
|
845 |
+
init_variables = self.module.init(
|
846 |
+
jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
|
847 |
+
)
|
848 |
+
return init_variables["cache"]
|
849 |
+
|
850 |
+
@add_start_docstrings_to_model_forward("")
|
851 |
+
def __call__(
|
852 |
+
self,
|
853 |
+
input_ids,
|
854 |
+
attention_mask=None,
|
855 |
+
position_ids=None,
|
856 |
+
params: dict = None,
|
857 |
+
past_key_values: dict = None,
|
858 |
+
dropout_rng: jax.random.PRNGKey = None,
|
859 |
+
train: bool = False,
|
860 |
+
output_attentions: Optional[bool] = None,
|
861 |
+
output_hidden_states: Optional[bool] = None,
|
862 |
+
return_dict: Optional[bool] = None,
|
863 |
+
):
|
864 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
865 |
+
output_hidden_states = (
|
866 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
867 |
+
)
|
868 |
+
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
869 |
+
|
870 |
+
batch_size, sequence_length = input_ids.shape
|
871 |
+
|
872 |
+
if position_ids is None:
|
873 |
+
if past_key_values is not None:
|
874 |
+
raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")
|
875 |
+
|
876 |
+
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
|
877 |
+
|
878 |
+
if attention_mask is None:
|
879 |
+
attention_mask = jnp.ones((batch_size, sequence_length))
|
880 |
+
|
881 |
+
# Handle any PRNG if needed
|
882 |
+
rngs = {}
|
883 |
+
if dropout_rng is not None:
|
884 |
+
rngs["dropout"] = dropout_rng
|
885 |
+
|
886 |
+
inputs = {"params": params or self.params}
|
887 |
+
|
888 |
+
# if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be changed by FlaxGPTJAttention module
|
889 |
+
if past_key_values:
|
890 |
+
inputs["cache"] = past_key_values
|
891 |
+
mutable = ["cache"]
|
892 |
+
else:
|
893 |
+
mutable = False
|
894 |
+
|
895 |
+
outputs = self.module.apply(
|
896 |
+
inputs,
|
897 |
+
jnp.array(input_ids, dtype="i4"),
|
898 |
+
jnp.array(attention_mask, dtype="i4"),
|
899 |
+
jnp.array(position_ids, dtype="i4"),
|
900 |
+
not train,
|
901 |
+
False,
|
902 |
+
output_attentions,
|
903 |
+
output_hidden_states,
|
904 |
+
return_dict,
|
905 |
+
rngs=rngs,
|
906 |
+
mutable=mutable,
|
907 |
+
)
|
908 |
+
|
909 |
+
# add updated cache to model output
|
910 |
+
if past_key_values is not None and return_dict:
|
911 |
+
outputs, past_key_values = outputs
|
912 |
+
outputs["past_key_values"] = unfreeze(past_key_values["cache"])
|
913 |
+
return outputs
|
914 |
+
elif past_key_values is not None and not return_dict:
|
915 |
+
outputs, past_key_values = outputs
|
916 |
+
outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
|
917 |
+
|
918 |
+
return outputs
|
919 |
+
|
920 |
+
|
921 |
+
class FlaxLLaMABlockCollection(nn.Module):
|
922 |
+
config: LLaMAConfig
|
923 |
+
dtype: jnp.dtype = jnp.float32
|
924 |
+
param_dtype: jnp.dtype=jnp.float32
|
925 |
+
precision: Optional[Union[jax.lax.Precision, str]]=None
|
926 |
+
|
927 |
+
def setup(self):
|
928 |
+
block = FlaxLLaMABlock
|
929 |
+
if self.config.remat_block != '':
|
930 |
+
block = remat(
|
931 |
+
FlaxLLaMABlock, static_argnums=(3, 4, 5),
|
932 |
+
policy=get_gradient_checkpoint_policy(self.config.remat_block)
|
933 |
+
)
|
934 |
+
self.blocks = [
|
935 |
+
block(
|
936 |
+
self.config,
|
937 |
+
name=str(i),
|
938 |
+
dtype=self.dtype,
|
939 |
+
param_dtype=self.param_dtype,
|
940 |
+
precision=self.precision
|
941 |
+
) for i in range(self.config.num_hidden_layers)
|
942 |
+
]
|
943 |
+
|
944 |
+
def __call__(
|
945 |
+
self,
|
946 |
+
hidden_states,
|
947 |
+
attention_mask=None,
|
948 |
+
position_ids=None,
|
949 |
+
deterministic: bool = True,
|
950 |
+
init_cache: bool = False,
|
951 |
+
output_attentions: bool = False,
|
952 |
+
output_hidden_states: bool = False,
|
953 |
+
return_dict: bool = True,
|
954 |
+
):
|
955 |
+
all_attentions = () if output_attentions else None
|
956 |
+
all_hidden_states = () if output_hidden_states else None
|
957 |
+
|
958 |
+
if not deterministic and self.config.fcm_max_ratio > 0:
|
959 |
+
# Apply forgetful causal mask
|
960 |
+
batch_size, seq_length = hidden_states.shape[0], hidden_states.shape[1]
|
961 |
+
fcm_ratio = jax.random.uniform(
|
962 |
+
self.make_rng('fcm'), shape=(batch_size, 1, 1, 1),
|
963 |
+
minval=self.config.fcm_min_ratio,
|
964 |
+
maxval=self.config.fcm_max_ratio
|
965 |
+
)
|
966 |
+
fcm_mask = jax.random.uniform(
|
967 |
+
self.make_rng('fcm'),
|
968 |
+
shape=(batch_size, 1, 1, seq_length)
|
969 |
+
) > fcm_ratio
|
970 |
+
fcm_mask = fcm_mask.at[:, :, :, 0].set(True)
|
971 |
+
fcm_mask = fcm_mask.astype('bool')
|
972 |
+
else:
|
973 |
+
fcm_mask = None
|
974 |
+
|
975 |
+
for block in self.blocks:
|
976 |
+
if output_hidden_states:
|
977 |
+
all_hidden_states += (hidden_states,)
|
978 |
+
|
979 |
+
layer_outputs = block(
|
980 |
+
hidden_states,
|
981 |
+
attention_mask,
|
982 |
+
position_ids,
|
983 |
+
deterministic,
|
984 |
+
init_cache,
|
985 |
+
output_attentions,
|
986 |
+
fcm_mask,
|
987 |
+
)
|
988 |
+
hidden_states = layer_outputs[0]
|
989 |
+
|
990 |
+
if output_attentions:
|
991 |
+
all_attentions += (layer_outputs[1],)
|
992 |
+
|
993 |
+
# this contains possible `None` values - `FlaxGPTJModule` will filter them out
|
994 |
+
outputs = (hidden_states, all_hidden_states, all_attentions)
|
995 |
+
|
996 |
+
return outputs
|
997 |
+
|
998 |
+
|
999 |
+
class FlaxLLaMAModule(nn.Module):
|
1000 |
+
config: LLaMAConfig
|
1001 |
+
dtype: jnp.dtype = jnp.float32
|
1002 |
+
param_dtype: jnp.dtype=jnp.float32
|
1003 |
+
precision: Optional[Union[jax.lax.Precision, str]]=None
|
1004 |
+
|
1005 |
+
def setup(self):
|
1006 |
+
self.embed_dim = self.config.hidden_size
|
1007 |
+
|
1008 |
+
self.wte = nn.Embed(
|
1009 |
+
self.config.vocab_size,
|
1010 |
+
self.config.hidden_size,
|
1011 |
+
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
1012 |
+
dtype=self.dtype,
|
1013 |
+
param_dtype=self.param_dtype,
|
1014 |
+
)
|
1015 |
+
self.dropout = nn.Dropout(rate=self.config.embd_pdrop)
|
1016 |
+
self.h = FlaxLLaMABlockCollection(self.config, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision)
|
1017 |
+
self.ln_f = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps, dtype=self.dtype, param_dtype=self.param_dtype)
|
1018 |
+
|
1019 |
+
def __call__(
|
1020 |
+
self,
|
1021 |
+
input_ids,
|
1022 |
+
attention_mask,
|
1023 |
+
position_ids,
|
1024 |
+
deterministic=True,
|
1025 |
+
init_cache: bool = False,
|
1026 |
+
output_attentions: bool = False,
|
1027 |
+
output_hidden_states: bool = False,
|
1028 |
+
return_dict: bool = True,
|
1029 |
+
):
|
1030 |
+
input_embeds = self.wte(input_ids.astype("i4"))
|
1031 |
+
|
1032 |
+
hidden_states = self.dropout(input_embeds, deterministic=deterministic)
|
1033 |
+
|
1034 |
+
outputs = self.h(
|
1035 |
+
hidden_states,
|
1036 |
+
attention_mask,
|
1037 |
+
position_ids=position_ids,
|
1038 |
+
deterministic=deterministic,
|
1039 |
+
init_cache=init_cache,
|
1040 |
+
output_attentions=output_attentions,
|
1041 |
+
output_hidden_states=output_hidden_states,
|
1042 |
+
return_dict=return_dict,
|
1043 |
+
)
|
1044 |
+
|
1045 |
+
hidden_states = outputs[0]
|
1046 |
+
hidden_states = self.ln_f(hidden_states)
|
1047 |
+
|
1048 |
+
if output_hidden_states:
|
1049 |
+
all_hidden_states = outputs[1] + (hidden_states,)
|
1050 |
+
outputs = (hidden_states, all_hidden_states) + outputs[2:]
|
1051 |
+
else:
|
1052 |
+
outputs = (hidden_states,) + outputs[1:]
|
1053 |
+
|
1054 |
+
if not return_dict:
|
1055 |
+
return tuple(v for v in outputs if v is not None)
|
1056 |
+
|
1057 |
+
return FlaxBaseModelOutput(
|
1058 |
+
last_hidden_state=hidden_states,
|
1059 |
+
hidden_states=outputs[1],
|
1060 |
+
attentions=outputs[-1],
|
1061 |
+
)
|
1062 |
+
|
1063 |
+
@add_start_docstrings("", "")
|
1064 |
+
class FlaxLLaMAModel(FlaxLLaMAPreTrainedModel):
|
1065 |
+
module_class = FlaxLLaMAModule
|
1066 |
+
|
1067 |
+
# append_call_sample_docstring(
|
1068 |
+
# FlaxLLaMAModel,
|
1069 |
+
# _TOKENIZER_FOR_DOC,
|
1070 |
+
# _CHECKPOINT_FOR_DOC,
|
1071 |
+
# FlaxCausalLMOutput,
|
1072 |
+
# _CONFIG_FOR_DOC,
|
1073 |
+
# )
|
1074 |
+
|
1075 |
+
class FlaxLLaMAForCausalLMModule(nn.Module):
|
1076 |
+
config: LLaMAConfig
|
1077 |
+
dtype: jnp.dtype = jnp.float32
|
1078 |
+
param_dtype: jnp.dtype=jnp.float32
|
1079 |
+
precision: Optional[Union[jax.lax.Precision, str]]=None
|
1080 |
+
|
1081 |
+
def setup(self):
|
1082 |
+
self.transformer = FlaxLLaMAModule(self.config, dtype=self.dtype)
|
1083 |
+
self.lm_head = nn.Dense(
|
1084 |
+
self.config.vocab_size,
|
1085 |
+
dtype=self.dtype,
|
1086 |
+
param_dtype=self.param_dtype,
|
1087 |
+
use_bias=False,
|
1088 |
+
kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
1089 |
+
precision=self.precision,
|
1090 |
+
)
|
1091 |
+
|
1092 |
+
def __call__(
|
1093 |
+
self,
|
1094 |
+
input_ids,
|
1095 |
+
attention_mask=None,
|
1096 |
+
position_ids=None,
|
1097 |
+
deterministic: bool = True,
|
1098 |
+
init_cache: bool = False,
|
1099 |
+
output_attentions: bool = False,
|
1100 |
+
output_hidden_states: bool = False,
|
1101 |
+
return_dict: bool = True,
|
1102 |
+
):
|
1103 |
+
batch_size, seq_length = input_ids.shape
|
1104 |
+
if attention_mask is None:
|
1105 |
+
attention_mask = jnp.ones_like(input_ids)
|
1106 |
+
if position_ids is None:
|
1107 |
+
position_ids = jnp.broadcast_to(
|
1108 |
+
jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0),
|
1109 |
+
(batch_size, seq_length)
|
1110 |
+
)
|
1111 |
+
outputs = self.transformer(
|
1112 |
+
input_ids,
|
1113 |
+
attention_mask,
|
1114 |
+
position_ids,
|
1115 |
+
deterministic=deterministic,
|
1116 |
+
init_cache=init_cache,
|
1117 |
+
output_attentions=output_attentions,
|
1118 |
+
output_hidden_states=output_hidden_states,
|
1119 |
+
return_dict=return_dict,
|
1120 |
+
)
|
1121 |
+
|
1122 |
+
hidden_states = outputs[0]
|
1123 |
+
|
1124 |
+
if self.config.tie_word_embeddings:
|
1125 |
+
shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T
|
1126 |
+
lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
|
1127 |
+
else:
|
1128 |
+
lm_logits = self.lm_head(hidden_states)
|
1129 |
+
|
1130 |
+
if not return_dict:
|
1131 |
+
return (lm_logits,) + outputs[1:]
|
1132 |
+
|
1133 |
+
return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
|
1134 |
+
|
1135 |
+
|
1136 |
+
@add_start_docstrings("", "")
|
1137 |
+
class FlaxLLaMAForCausalLM(FlaxLLaMAPreTrainedModel):
|
1138 |
+
module_class = FlaxLLaMAForCausalLMModule
|
1139 |
+
|
1140 |
+
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
|
1141 |
+
# initializing the cache
|
1142 |
+
batch_size, seq_length = input_ids.shape
|
1143 |
+
|
1144 |
+
past_key_values = self.init_cache(batch_size, max_length)
|
1145 |
+
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
|
1146 |
+
# But since GPTJ uses a causal mask, those positions are masked anyways.
|
1147 |
+
# Thus we can create a single static attention_mask here, which is more efficient for compilation
|
1148 |
+
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
|
1149 |
+
if attention_mask is not None:
|
1150 |
+
position_ids = attention_mask.cumsum(axis=-1) - 1
|
1151 |
+
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
|
1152 |
+
else:
|
1153 |
+
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
|
1154 |
+
|
1155 |
+
return {
|
1156 |
+
"past_key_values": past_key_values,
|
1157 |
+
"attention_mask": extended_attention_mask,
|
1158 |
+
"position_ids": position_ids,
|
1159 |
+
}
|
1160 |
+
|
1161 |
+
def update_inputs_for_generation(self, model_outputs, model_kwargs):
|
1162 |
+
model_kwargs["past_key_values"] = model_outputs.past_key_values
|
1163 |
+
model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
|
1164 |
+
return model_kwargs
|
1165 |
+
|
1166 |
+
# append_call_sample_docstring(
|
1167 |
+
# FlaxGPTJForCausalLM,
|
1168 |
+
# _TOKENIZER_FOR_DOC,
|
1169 |
+
# _CHECKPOINT_FOR_DOC,
|
1170 |
+
# FlaxCausalLMOutput,
|
1171 |
+
# _CONFIG_FOR_DOC,
|
1172 |
+
# )
|
1173 |
+
|
1174 |
+
|
1175 |
+
|
1176 |
+
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
|
1177 |
+
|
1178 |
+
PRETRAINED_VOCAB_FILES_MAP = {}
|
1179 |
+
|
1180 |
+
|
1181 |
+
class LLaMATokenizer(PreTrainedTokenizer):
|
1182 |
+
"""
|
1183 |
+
Construct a LLaMA tokenizer. Based on byte-level Byte-Pair-Encoding.
|
1184 |
+
Args:
|
1185 |
+
vocab_file (`str`):
|
1186 |
+
Path to the vocabulary file.
|
1187 |
+
"""
|
1188 |
+
|
1189 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
1190 |
+
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
1191 |
+
model_input_names = ["input_ids", "attention_mask"]
|
1192 |
+
|
1193 |
+
def __init__(
|
1194 |
+
self,
|
1195 |
+
vocab_file,
|
1196 |
+
unk_token="<unk>",
|
1197 |
+
bos_token="<s>",
|
1198 |
+
eos_token="</s>",
|
1199 |
+
sp_model_kwargs: Optional[Dict[str, Any]] = None,
|
1200 |
+
add_bos_token=False,
|
1201 |
+
add_eos_token=False,
|
1202 |
+
**kwargs,
|
1203 |
+
):
|
1204 |
+
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
|
1205 |
+
super().__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs)
|
1206 |
+
self.vocab_file = vocab_file
|
1207 |
+
self.add_bos_token = add_bos_token
|
1208 |
+
self.add_eos_token = add_eos_token
|
1209 |
+
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
1210 |
+
|
1211 |
+
with tempfile.NamedTemporaryFile() as tfile:
|
1212 |
+
with open_file(self.vocab_file, 'rb') as fin:
|
1213 |
+
tfile.write(fin.read())
|
1214 |
+
tfile.flush()
|
1215 |
+
tfile.seek(0)
|
1216 |
+
self.sp_model.Load(tfile.name)
|
1217 |
+
""" Initialisation"""
|
1218 |
+
self.add_special_tokens(dict(
|
1219 |
+
unk_token=unk_token,
|
1220 |
+
bos_token=bos_token,
|
1221 |
+
eos_token=eos_token,
|
1222 |
+
))
|
1223 |
+
self.pad_token_id = self.unk_token_id
|
1224 |
+
|
1225 |
+
@property
|
1226 |
+
def vocab_size(self):
|
1227 |
+
"""Returns vocab size"""
|
1228 |
+
return self.sp_model.get_piece_size()
|
1229 |
+
|
1230 |
+
@property
|
1231 |
+
def bos_token_id(self) -> Optional[int]:
|
1232 |
+
return self.sp_model.bos_id()
|
1233 |
+
|
1234 |
+
@property
|
1235 |
+
def eos_token_id(self) -> Optional[int]:
|
1236 |
+
return self.sp_model.eos_id()
|
1237 |
+
|
1238 |
+
def get_vocab(self):
|
1239 |
+
"""Returns vocab as a dict"""
|
1240 |
+
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
|
1241 |
+
vocab.update(self.added_tokens_encoder)
|
1242 |
+
return vocab
|
1243 |
+
|
1244 |
+
def _tokenize(self, text):
|
1245 |
+
"""Returns a tokenized string."""
|
1246 |
+
return self.sp_model.encode(text, out_type=str)
|
1247 |
+
|
1248 |
+
def _convert_token_to_id(self, token):
|
1249 |
+
"""Converts a token (str) in an id using the vocab."""
|
1250 |
+
return self.sp_model.piece_to_id(token)
|
1251 |
+
|
1252 |
+
def _convert_id_to_token(self, index):
|
1253 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
1254 |
+
token = self.sp_model.IdToPiece(index)
|
1255 |
+
return token
|
1256 |
+
|
1257 |
+
def convert_tokens_to_string(self, tokens):
|
1258 |
+
"""Converts a sequence of tokens (string) in a single string."""
|
1259 |
+
current_sub_tokens = []
|
1260 |
+
out_string = ""
|
1261 |
+
prev_is_special = False
|
1262 |
+
for token in tokens:
|
1263 |
+
# make sure that special tokens are not decoded using sentencepiece model
|
1264 |
+
if token in self.all_special_tokens:
|
1265 |
+
if not prev_is_special:
|
1266 |
+
out_string += " "
|
1267 |
+
out_string += self.sp_model.decode(current_sub_tokens) + token
|
1268 |
+
prev_is_special = True
|
1269 |
+
current_sub_tokens = []
|
1270 |
+
else:
|
1271 |
+
current_sub_tokens.append(token)
|
1272 |
+
prev_is_special = False
|
1273 |
+
out_string += self.sp_model.decode(current_sub_tokens)
|
1274 |
+
return out_string.strip()
|
1275 |
+
|
1276 |
+
def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
1277 |
+
"""
|
1278 |
+
Save the vocabulary and special tokens file to a directory.
|
1279 |
+
Args:
|
1280 |
+
save_directory (`str`):
|
1281 |
+
The directory in which to save the vocabulary.
|
1282 |
+
Returns:
|
1283 |
+
`Tuple(str)`: Paths to the files saved.
|
1284 |
+
"""
|
1285 |
+
if not os.path.isdir(save_directory):
|
1286 |
+
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
1287 |
+
return
|
1288 |
+
out_vocab_file = os.path.join(
|
1289 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
1290 |
+
)
|
1291 |
+
|
1292 |
+
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
|
1293 |
+
copyfile(self.vocab_file, out_vocab_file)
|
1294 |
+
elif not os.path.isfile(self.vocab_file):
|
1295 |
+
with open(out_vocab_file, "wb") as fi:
|
1296 |
+
content_spiece_model = self.sp_model.serialized_model_proto()
|
1297 |
+
fi.write(content_spiece_model)
|
1298 |
+
|
1299 |
+
return (out_vocab_file,)
|
1300 |
+
|
1301 |
+
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
1302 |
+
if self.add_bos_token:
|
1303 |
+
bos_token_ids = [self.bos_token_id]
|
1304 |
+
else:
|
1305 |
+
bos_token_ids = []
|
1306 |
+
|
1307 |
+
output = bos_token_ids + token_ids_0
|
1308 |
+
|
1309 |
+
if token_ids_1 is not None:
|
1310 |
+
output = output + token_ids_1
|
1311 |
+
|
1312 |
+
if self.add_eos_token:
|
1313 |
+
output = output + [self.eos_token_id]
|
1314 |
+
|
1315 |
+
return output
|
1316 |
+
|
1317 |
+
def get_special_tokens_mask(
|
1318 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
1319 |
+
) -> List[int]:
|
1320 |
+
"""
|
1321 |
+
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
1322 |
+
special tokens using the tokenizer `prepare_for_model` method.
|
1323 |
+
Args:
|
1324 |
+
token_ids_0 (`List[int]`):
|
1325 |
+
List of IDs.
|
1326 |
+
token_ids_1 (`List[int]`, *optional*):
|
1327 |
+
Optional second list of IDs for sequence pairs.
|
1328 |
+
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
1329 |
+
Whether or not the token list is already formatted with special tokens for the model.
|
1330 |
+
Returns:
|
1331 |
+
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
1332 |
+
"""
|
1333 |
+
if already_has_special_tokens:
|
1334 |
+
return super().get_special_tokens_mask(
|
1335 |
+
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
1336 |
+
)
|
1337 |
+
|
1338 |
+
if token_ids_1 is None:
|
1339 |
+
return [1] + ([0] * len(token_ids_0)) + [1]
|
1340 |
+
return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
|
1341 |
+
|
1342 |
+
def create_token_type_ids_from_sequences(
|
1343 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
1344 |
+
) -> List[int]:
|
1345 |
+
"""
|
1346 |
+
Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
|
1347 |
+
use of token type ids, therefore a list of zeros is returned.
|
1348 |
+
Args:
|
1349 |
+
token_ids_0 (`List[int]`):
|
1350 |
+
List of IDs.
|
1351 |
+
token_ids_1 (`List[int]`, *optional*):
|
1352 |
+
Optional second list of IDs for sequence pairs.
|
1353 |
+
Returns:
|
1354 |
+
`List[int]`: List of zeros.
|
1355 |
+
"""
|
1356 |
+
eos = [self.eos_token_id]
|
1357 |
+
|
1358 |
+
if token_ids_1 is None:
|
1359 |
+
return len(token_ids_0 + eos) * [0]
|
1360 |
+
return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
|
EasyLM/models/llama/llama_serve.py
ADDED
@@ -0,0 +1,386 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pprint
|
2 |
+
from functools import partial
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import mlxu
|
6 |
+
|
7 |
+
import jax
|
8 |
+
import jax.numpy as jnp
|
9 |
+
from jax.experimental.pjit import pjit
|
10 |
+
from jax.sharding import PartitionSpec as PS
|
11 |
+
import optax
|
12 |
+
from transformers import GenerationConfig, FlaxLogitsProcessorList
|
13 |
+
|
14 |
+
from EasyLM.checkpoint import StreamingCheckpointer
|
15 |
+
from EasyLM.serving import LMServer
|
16 |
+
from EasyLM.jax_utils import (
|
17 |
+
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, tree_apply,
|
18 |
+
set_random_seed, get_float_dtype_by_name, make_shard_and_gather_fns,
|
19 |
+
with_sharding_constraint, FlaxTemperatureLogitsWarper
|
20 |
+
)
|
21 |
+
from EasyLM.models.llama.llama_model import LLaMAConfig, FlaxLLaMAForCausalLM
|
22 |
+
|
23 |
+
|
24 |
+
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
25 |
+
seed=42,
|
26 |
+
initialize_jax_distributed=False,
|
27 |
+
mesh_dim='1,-1,1',
|
28 |
+
dtype='bf16',
|
29 |
+
input_length=1024,
|
30 |
+
seq_length=2048,
|
31 |
+
top_k=50,
|
32 |
+
top_p=1.0,
|
33 |
+
do_sample=True,
|
34 |
+
num_beams=1,
|
35 |
+
add_bos_token=True,
|
36 |
+
load_llama_config='',
|
37 |
+
load_checkpoint='',
|
38 |
+
tokenizer=LLaMAConfig.get_tokenizer_config(),
|
39 |
+
lm_server=LMServer.get_default_config(),
|
40 |
+
jax_distributed=JaxDistributedConfig.get_default_config(),
|
41 |
+
)
|
42 |
+
|
43 |
+
|
44 |
+
def main(argv):
|
45 |
+
JaxDistributedConfig.initialize(FLAGS.jax_distributed)
|
46 |
+
set_random_seed(FLAGS.seed)
|
47 |
+
|
48 |
+
prefix_tokenizer = LLaMAConfig.get_tokenizer(
|
49 |
+
FLAGS.tokenizer, truncation_side='left', padding_side='left'
|
50 |
+
)
|
51 |
+
tokenizer = LLaMAConfig.get_tokenizer(
|
52 |
+
FLAGS.tokenizer, truncation_side='right', padding_side='right'
|
53 |
+
)
|
54 |
+
|
55 |
+
with jax.default_device(jax.devices("cpu")[0]):
|
56 |
+
llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config)
|
57 |
+
_, params = StreamingCheckpointer.load_trainstate_checkpoint(
|
58 |
+
FLAGS.load_checkpoint, disallow_trainstate=True
|
59 |
+
)
|
60 |
+
|
61 |
+
hf_model = FlaxLLaMAForCausalLM(
|
62 |
+
llama_config,
|
63 |
+
input_shape=(1, FLAGS.seq_length),
|
64 |
+
seed=FLAGS.seed,
|
65 |
+
_do_init=False
|
66 |
+
)
|
67 |
+
|
68 |
+
model_ps = match_partition_rules(
|
69 |
+
LLaMAConfig.get_partition_rules(), params
|
70 |
+
)
|
71 |
+
shard_fns, _ = make_shard_and_gather_fns(
|
72 |
+
model_ps, get_float_dtype_by_name(FLAGS.dtype)
|
73 |
+
)
|
74 |
+
|
75 |
+
@partial(
|
76 |
+
pjit,
|
77 |
+
in_shardings=(model_ps, PS(), PS()),
|
78 |
+
out_shardings=(PS(), PS(), PS())
|
79 |
+
)
|
80 |
+
def forward_loglikelihood(params, rng, batch):
|
81 |
+
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
|
82 |
+
rng_generator = JaxRNG(rng)
|
83 |
+
input_tokens = batch['input_tokens']
|
84 |
+
output_tokens = batch['output_tokens']
|
85 |
+
input_mask = batch['input_mask']
|
86 |
+
output_mask = batch['output_mask']
|
87 |
+
|
88 |
+
logits = hf_model.module.apply(
|
89 |
+
params, input_tokens, attention_mask=input_mask,
|
90 |
+
deterministic=True, rngs=rng_generator(llama_config.rng_keys()),
|
91 |
+
).logits
|
92 |
+
# if llama_config.n_real_tokens is not None:
|
93 |
+
# logits = logits.at[:, :, llama_config.n_real_tokens:].set(-1e8)
|
94 |
+
loglikelihood = -optax.softmax_cross_entropy_with_integer_labels(
|
95 |
+
logits, output_tokens
|
96 |
+
)
|
97 |
+
loglikelihood = jnp.sum(loglikelihood * output_mask, axis=-1)
|
98 |
+
match_count = jnp.sum(
|
99 |
+
(jnp.argmax(logits, axis=-1) == output_tokens) * output_mask,
|
100 |
+
axis=-1
|
101 |
+
)
|
102 |
+
total = jnp.sum(output_mask, axis=-1)
|
103 |
+
is_greedy = match_count == total
|
104 |
+
return loglikelihood, is_greedy, rng_generator()
|
105 |
+
|
106 |
+
|
107 |
+
@partial(
|
108 |
+
pjit,
|
109 |
+
in_shardings=(model_ps, PS(), PS(), PS()),
|
110 |
+
out_shardings=(PS(), PS())
|
111 |
+
)
|
112 |
+
def forward_generate(params, rng, batch, temperature):
|
113 |
+
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
|
114 |
+
rng_generator = JaxRNG(rng)
|
115 |
+
output = hf_model.generate(
|
116 |
+
batch['input_tokens'],
|
117 |
+
attention_mask=batch['attention_mask'],
|
118 |
+
params=params['params'],
|
119 |
+
prng_key=rng_generator(),
|
120 |
+
logits_processor=FlaxLogitsProcessorList(
|
121 |
+
[FlaxTemperatureLogitsWarper(temperature)]
|
122 |
+
),
|
123 |
+
generation_config=GenerationConfig(
|
124 |
+
max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
|
125 |
+
pad_token_id=tokenizer.eos_token_id,
|
126 |
+
bos_token_id=tokenizer.bos_token_id,
|
127 |
+
eos_token_id=tokenizer.eos_token_id,
|
128 |
+
do_sample=FLAGS.do_sample,
|
129 |
+
num_beams=FLAGS.num_beams,
|
130 |
+
top_k=FLAGS.top_k,
|
131 |
+
top_p=FLAGS.top_p,
|
132 |
+
)
|
133 |
+
).sequences[:, batch['input_tokens'].shape[1]:]
|
134 |
+
return output, rng_generator()
|
135 |
+
|
136 |
+
@partial(
|
137 |
+
pjit,
|
138 |
+
in_shardings=(model_ps, PS(), PS()),
|
139 |
+
out_shardings=(PS(), PS())
|
140 |
+
)
|
141 |
+
def forward_greedy_generate(params, rng, batch):
|
142 |
+
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
|
143 |
+
rng_generator = JaxRNG(rng)
|
144 |
+
output = hf_model.generate(
|
145 |
+
batch['input_tokens'],
|
146 |
+
attention_mask=batch['attention_mask'],
|
147 |
+
params=params['params'],
|
148 |
+
prng_key=rng_generator(),
|
149 |
+
generation_config=GenerationConfig(
|
150 |
+
max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
|
151 |
+
pad_token_id=tokenizer.eos_token_id,
|
152 |
+
bos_token_id=tokenizer.bos_token_id,
|
153 |
+
eos_token_id=tokenizer.eos_token_id,
|
154 |
+
do_sample=False,
|
155 |
+
num_beams=1,
|
156 |
+
)
|
157 |
+
).sequences[:, batch['input_tokens'].shape[1]:]
|
158 |
+
return output, rng_generator()
|
159 |
+
|
160 |
+
mesh = LLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
|
161 |
+
with mesh:
|
162 |
+
params = tree_apply(shard_fns, params)
|
163 |
+
sharded_rng = next_rng()
|
164 |
+
|
165 |
+
class ModelServer(LMServer):
|
166 |
+
|
167 |
+
@staticmethod
|
168 |
+
def loglikelihood(prefix_text, text):
|
169 |
+
nonlocal sharded_rng
|
170 |
+
prefix = prefix_tokenizer(
|
171 |
+
prefix_text,
|
172 |
+
padding='max_length',
|
173 |
+
truncation=True,
|
174 |
+
max_length=FLAGS.input_length,
|
175 |
+
return_tensors='np',
|
176 |
+
)
|
177 |
+
inputs = tokenizer(
|
178 |
+
text,
|
179 |
+
padding='max_length',
|
180 |
+
truncation=True,
|
181 |
+
max_length=FLAGS.seq_length - FLAGS.input_length,
|
182 |
+
return_tensors='np',
|
183 |
+
)
|
184 |
+
output_tokens = np.concatenate([prefix.input_ids, inputs.input_ids], axis=1)
|
185 |
+
bos_tokens = np.full(
|
186 |
+
(output_tokens.shape[0], 1), tokenizer.bos_token_id, dtype=np.int32
|
187 |
+
)
|
188 |
+
input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1)
|
189 |
+
input_mask = np.concatenate(
|
190 |
+
[prefix.attention_mask, inputs.attention_mask], axis=1
|
191 |
+
)
|
192 |
+
if FLAGS.add_bos_token:
|
193 |
+
bos_mask = np.ones_like(input_mask[:, :1])
|
194 |
+
else:
|
195 |
+
bos_mask = np.zeros_like(input_mask[:, :1])
|
196 |
+
|
197 |
+
input_mask = np.concatenate([bos_mask, input_mask[:, :-1]], axis=1)
|
198 |
+
output_mask = np.concatenate(
|
199 |
+
[np.zeros_like(prefix.attention_mask), inputs.attention_mask], axis=1
|
200 |
+
)
|
201 |
+
batch = dict(
|
202 |
+
input_tokens=input_tokens,
|
203 |
+
output_tokens=output_tokens,
|
204 |
+
input_mask=input_mask,
|
205 |
+
output_mask=output_mask,
|
206 |
+
)
|
207 |
+
with mesh:
|
208 |
+
loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
|
209 |
+
params, sharded_rng, batch
|
210 |
+
)
|
211 |
+
loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy))
|
212 |
+
return loglikelihood, is_greedy
|
213 |
+
|
214 |
+
@staticmethod
|
215 |
+
def loglikelihood_rolling(text):
|
216 |
+
nonlocal sharded_rng
|
217 |
+
inputs = tokenizer(
|
218 |
+
text,
|
219 |
+
padding='longest',
|
220 |
+
truncation=False,
|
221 |
+
max_length=np.iinfo(np.int32).max,
|
222 |
+
return_tensors='np',
|
223 |
+
)
|
224 |
+
batch_size = inputs.input_ids.shape[0]
|
225 |
+
output_tokens = inputs.input_ids
|
226 |
+
attention_mask = inputs.attention_mask
|
227 |
+
|
228 |
+
if output_tokens.shape[1] < FLAGS.seq_length:
|
229 |
+
padding_length = FLAGS.seq_length - output_tokens.shape[1]
|
230 |
+
pad_tokens = np.full(
|
231 |
+
(batch_size, padding_length), tokenizer.pad_token_id, dtype=np.int32
|
232 |
+
)
|
233 |
+
output_tokens = np.concatenate([output_tokens, pad_tokens], axis=-1)
|
234 |
+
pad_mask = np.zeros(
|
235 |
+
(batch_size, padding_length), dtype=inputs.attention_mask.dtype
|
236 |
+
)
|
237 |
+
attention_mask = np.concatenate([attention_mask, pad_mask], axis=-1)
|
238 |
+
|
239 |
+
bos_tokens = np.full(
|
240 |
+
(batch_size, 1), tokenizer.bos_token_id, dtype=np.int32
|
241 |
+
)
|
242 |
+
input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1)
|
243 |
+
bos_mask = np.ones((batch_size, 1), dtype=inputs.attention_mask.dtype)
|
244 |
+
total_seq_length = output_tokens.shape[1]
|
245 |
+
|
246 |
+
total_loglikelihood = 0.0
|
247 |
+
total_is_greedy = True
|
248 |
+
# Sliding window
|
249 |
+
for i in range(0, total_seq_length, FLAGS.seq_length):
|
250 |
+
# Last window
|
251 |
+
if i + FLAGS.seq_length > total_seq_length:
|
252 |
+
last_output_mask = np.copy(attention_mask[:, -FLAGS.seq_length:])
|
253 |
+
last_output_mask[:, :i - total_seq_length] = 0.0
|
254 |
+
|
255 |
+
batch = dict(
|
256 |
+
input_tokens=input_tokens[:, -FLAGS.seq_length:],
|
257 |
+
output_tokens=output_tokens[:, -FLAGS.seq_length:],
|
258 |
+
input_mask=attention_mask[:, -FLAGS.seq_length:],
|
259 |
+
output_mask=last_output_mask,
|
260 |
+
)
|
261 |
+
|
262 |
+
# Normal window
|
263 |
+
else:
|
264 |
+
batch = dict(
|
265 |
+
input_tokens=input_tokens[:, i:i + FLAGS.seq_length],
|
266 |
+
output_tokens=output_tokens[:, i:i + FLAGS.seq_length],
|
267 |
+
input_mask=attention_mask[:, i:i + FLAGS.seq_length],
|
268 |
+
output_mask=attention_mask[:, i:i + FLAGS.seq_length],
|
269 |
+
)
|
270 |
+
|
271 |
+
with mesh:
|
272 |
+
loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
|
273 |
+
params, sharded_rng, batch
|
274 |
+
)
|
275 |
+
loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy))
|
276 |
+
|
277 |
+
total_loglikelihood += loglikelihood
|
278 |
+
total_is_greedy = np.logical_and(is_greedy, total_is_greedy)
|
279 |
+
|
280 |
+
return total_loglikelihood, total_is_greedy
|
281 |
+
|
282 |
+
@staticmethod
|
283 |
+
def generate(text, temperature):
|
284 |
+
nonlocal sharded_rng
|
285 |
+
inputs = prefix_tokenizer(
|
286 |
+
text,
|
287 |
+
padding='max_length',
|
288 |
+
truncation=True,
|
289 |
+
max_length=FLAGS.input_length,
|
290 |
+
return_tensors='np',
|
291 |
+
)
|
292 |
+
input_tokens = inputs.input_ids
|
293 |
+
input_mask = inputs.attention_mask
|
294 |
+
if FLAGS.add_bos_token:
|
295 |
+
input_tokens[:, 0] = tokenizer.bos_token_id
|
296 |
+
input_mask[:, 0] = 1
|
297 |
+
batch = dict(
|
298 |
+
input_tokens=input_tokens,
|
299 |
+
attention_mask=input_mask,
|
300 |
+
)
|
301 |
+
with mesh:
|
302 |
+
output, sharded_rng = forward_generate(
|
303 |
+
params, sharded_rng, batch, temperature
|
304 |
+
)
|
305 |
+
output = jax.device_get(output)
|
306 |
+
output_text = []
|
307 |
+
for text in list(tokenizer.batch_decode(output)):
|
308 |
+
if tokenizer.eos_token in text:
|
309 |
+
text = text.split(tokenizer.eos_token, maxsplit=1)[0]
|
310 |
+
output_text.append(text)
|
311 |
+
|
312 |
+
return output_text
|
313 |
+
|
314 |
+
@staticmethod
|
315 |
+
def greedy_until(prefix_text, until, max_length):
|
316 |
+
nonlocal sharded_rng
|
317 |
+
all_outputs = []
|
318 |
+
for pf, ut in zip(prefix_text, until):
|
319 |
+
if isinstance(ut, str):
|
320 |
+
ut = [ut]
|
321 |
+
total_length = 0
|
322 |
+
total_generated = ''
|
323 |
+
|
324 |
+
while total_length < max_length:
|
325 |
+
pf_tokens = tokenizer(
|
326 |
+
pf,
|
327 |
+
padding=False,
|
328 |
+
truncation=False,
|
329 |
+
max_length=np.iinfo(np.int32).max,
|
330 |
+
return_tensors='np',
|
331 |
+
)
|
332 |
+
input_tokens = pf_tokens.input_ids
|
333 |
+
attention_mask = pf_tokens.attention_mask
|
334 |
+
|
335 |
+
if input_tokens.shape[1] < FLAGS.input_length:
|
336 |
+
extra = FLAGS.input_length - input_tokens.shape[1]
|
337 |
+
pad_tokens = np.full(
|
338 |
+
(1, extra), tokenizer.pad_token_id, dtype=np.int32
|
339 |
+
)
|
340 |
+
input_tokens = np.concatenate(
|
341 |
+
[pad_tokens, input_tokens], axis=1
|
342 |
+
)
|
343 |
+
pad_attention = np.zeros((1, extra), dtype=attention_mask.dtype)
|
344 |
+
attention_mask = np.concatenate(
|
345 |
+
[pad_attention, attention_mask], axis=1
|
346 |
+
)
|
347 |
+
elif input_tokens.shape[1] > FLAGS.input_length:
|
348 |
+
input_tokens = input_tokens[:, -FLAGS.input_length:]
|
349 |
+
attention_mask = attention_mask[:, -FLAGS.input_length:]
|
350 |
+
|
351 |
+
if FLAGS.add_bos_token:
|
352 |
+
input_tokens[:, 0] = tokenizer.bos_token_id
|
353 |
+
attention_mask[:, 0] = 1
|
354 |
+
|
355 |
+
batch = dict(input_tokens=input_tokens, attention_mask=attention_mask)
|
356 |
+
|
357 |
+
with mesh:
|
358 |
+
output, sharded_rng = forward_greedy_generate(
|
359 |
+
params, sharded_rng, batch
|
360 |
+
)
|
361 |
+
output = jax.device_get(output)
|
362 |
+
|
363 |
+
total_length += output.shape[1]
|
364 |
+
output_text = tokenizer.batch_decode(output)[0]
|
365 |
+
total_generated = total_generated + output_text
|
366 |
+
pf = pf + output_text
|
367 |
+
|
368 |
+
done = False
|
369 |
+
for s in ut:
|
370 |
+
if s in total_generated:
|
371 |
+
total_generated = total_generated.split(s, maxsplit=1)[0]
|
372 |
+
done = True
|
373 |
+
if done:
|
374 |
+
break
|
375 |
+
|
376 |
+
all_outputs.append(total_generated)
|
377 |
+
|
378 |
+
return all_outputs
|
379 |
+
|
380 |
+
|
381 |
+
server = ModelServer(FLAGS.lm_server)
|
382 |
+
server.run()
|
383 |
+
|
384 |
+
|
385 |
+
if __name__ == "__main__":
|
386 |
+
mlxu.run(main)
|
EasyLM/models/llama/llama_train.py
ADDED
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pprint
|
2 |
+
from functools import partial
|
3 |
+
|
4 |
+
from tqdm import tqdm, trange
|
5 |
+
import numpy as np
|
6 |
+
import mlxu
|
7 |
+
|
8 |
+
import jax
|
9 |
+
import jax.numpy as jnp
|
10 |
+
from jax.experimental.pjit import pjit
|
11 |
+
from jax.sharding import PartitionSpec as PS
|
12 |
+
from flax.training.train_state import TrainState
|
13 |
+
|
14 |
+
from EasyLM.data import DatasetFactory
|
15 |
+
from EasyLM.checkpoint import StreamingCheckpointer
|
16 |
+
from EasyLM.optimizers import OptimizerFactory
|
17 |
+
from EasyLM.jax_utils import (
|
18 |
+
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules,
|
19 |
+
cross_entropy_loss_and_accuracy, global_norm, get_float_dtype_by_name,
|
20 |
+
set_random_seed, average_metrics, get_weight_decay_mask,
|
21 |
+
make_shard_and_gather_fns, with_sharding_constraint,
|
22 |
+
)
|
23 |
+
from EasyLM.models.llama.llama_model import (
|
24 |
+
LLaMAConfig, FlaxLLaMAForCausalLMModule
|
25 |
+
)
|
26 |
+
|
27 |
+
|
28 |
+
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
29 |
+
seed=42,
|
30 |
+
mesh_dim='1,-1,1',
|
31 |
+
dtype='fp32',
|
32 |
+
param_dtype='fp32',
|
33 |
+
total_steps=10000,
|
34 |
+
load_llama_config='',
|
35 |
+
update_llama_config='',
|
36 |
+
load_checkpoint='',
|
37 |
+
load_dataset_state='',
|
38 |
+
log_freq=50,
|
39 |
+
save_model_freq=0,
|
40 |
+
save_milestone_freq=0,
|
41 |
+
eval_freq=0,
|
42 |
+
tokenizer=LLaMAConfig.get_tokenizer_config(),
|
43 |
+
train_dataset=DatasetFactory.get_default_config(),
|
44 |
+
eval_dataset=DatasetFactory.get_default_config(),
|
45 |
+
optimizer=OptimizerFactory.get_default_config(),
|
46 |
+
checkpointer=StreamingCheckpointer.get_default_config(),
|
47 |
+
llama=LLaMAConfig.get_default_config(),
|
48 |
+
logger=mlxu.WandBLogger.get_default_config(),
|
49 |
+
log_all_worker=False,
|
50 |
+
jax_distributed=JaxDistributedConfig.get_default_config(),
|
51 |
+
)
|
52 |
+
|
53 |
+
|
54 |
+
def main(argv):
|
55 |
+
JaxDistributedConfig.initialize(FLAGS.jax_distributed)
|
56 |
+
variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
|
57 |
+
flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
|
58 |
+
logger = mlxu.WandBLogger(
|
59 |
+
config=FLAGS.logger,
|
60 |
+
variant=variant,
|
61 |
+
enable=FLAGS.log_all_worker or (jax.process_index() == 0),
|
62 |
+
)
|
63 |
+
set_random_seed(FLAGS.seed)
|
64 |
+
|
65 |
+
tokenizer = LLaMAConfig.get_tokenizer(FLAGS.tokenizer)
|
66 |
+
dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer)
|
67 |
+
if FLAGS.load_dataset_state != '':
|
68 |
+
dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state))
|
69 |
+
|
70 |
+
if FLAGS.eval_freq > 0:
|
71 |
+
eval_dataset = DatasetFactory.load_dataset(
|
72 |
+
FLAGS.eval_dataset, dataset.tokenizer, eval_dataset=True
|
73 |
+
)
|
74 |
+
|
75 |
+
seq_length = dataset.seq_length
|
76 |
+
|
77 |
+
if FLAGS.load_llama_config != '':
|
78 |
+
llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config)
|
79 |
+
else:
|
80 |
+
llama_config = LLaMAConfig(**FLAGS.llama)
|
81 |
+
|
82 |
+
if FLAGS.update_llama_config != '':
|
83 |
+
llama_config.update(dict(eval(FLAGS.update_llama_config)))
|
84 |
+
|
85 |
+
llama_config.update(dict(
|
86 |
+
bos_token_id=dataset.tokenizer.bos_token_id,
|
87 |
+
eos_token_id=dataset.tokenizer.eos_token_id,
|
88 |
+
))
|
89 |
+
if llama_config.vocab_size < dataset.vocab_size:
|
90 |
+
print("Updating model config vocab size from", llama_config.vocab_size, "to", dataset.vocab_size)
|
91 |
+
llama_config.update(dict(vocab_size=dataset.vocab_size))
|
92 |
+
|
93 |
+
model = FlaxLLaMAForCausalLMModule(
|
94 |
+
llama_config, dtype=get_float_dtype_by_name(FLAGS.dtype), param_dtype=get_float_dtype_by_name(FLAGS.param_dtype)
|
95 |
+
)
|
96 |
+
|
97 |
+
optimizer, optimizer_info = OptimizerFactory.get_optimizer(
|
98 |
+
FLAGS.optimizer,
|
99 |
+
get_weight_decay_mask(LLaMAConfig.get_weight_decay_exclusions())
|
100 |
+
)
|
101 |
+
|
102 |
+
def create_trainstate_from_params(params):
|
103 |
+
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
|
104 |
+
|
105 |
+
def init_fn(rng):
|
106 |
+
rng_generator = JaxRNG(rng)
|
107 |
+
params = model.init(
|
108 |
+
input_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
|
109 |
+
position_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
|
110 |
+
attention_mask=jnp.ones((4, seq_length), dtype=jnp.int32),
|
111 |
+
rngs=rng_generator(llama_config.rng_keys()),
|
112 |
+
)
|
113 |
+
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
|
114 |
+
|
115 |
+
def train_step(train_state, rng, batch):
|
116 |
+
rng_generator = JaxRNG(rng)
|
117 |
+
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
|
118 |
+
def loss_and_accuracy(params):
|
119 |
+
logits = model.apply(
|
120 |
+
params, batch['input_tokens'], deterministic=False,
|
121 |
+
rngs=rng_generator(llama_config.rng_keys()),
|
122 |
+
).logits
|
123 |
+
return cross_entropy_loss_and_accuracy(
|
124 |
+
logits, batch['target_tokens'], batch['loss_masks']
|
125 |
+
)
|
126 |
+
grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
|
127 |
+
(loss, accuracy), grads = grad_fn(train_state.params)
|
128 |
+
train_state = train_state.apply_gradients(grads=grads)
|
129 |
+
metrics = dict(
|
130 |
+
loss=loss,
|
131 |
+
accuracy=accuracy,
|
132 |
+
learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
|
133 |
+
gradient_norm=global_norm(grads),
|
134 |
+
param_norm=global_norm(train_state.params),
|
135 |
+
)
|
136 |
+
return train_state, rng_generator(), metrics
|
137 |
+
|
138 |
+
def eval_step(train_state, rng, batch):
|
139 |
+
rng_generator = JaxRNG(rng)
|
140 |
+
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
|
141 |
+
logits = model.apply(
|
142 |
+
train_state.params, batch['input_tokens'], deterministic=True,
|
143 |
+
rngs=rng_generator(llama_config.rng_keys()),
|
144 |
+
).logits
|
145 |
+
loss, accuracy = cross_entropy_loss_and_accuracy(
|
146 |
+
logits, batch['target_tokens'], batch['loss_masks']
|
147 |
+
)
|
148 |
+
metrics = dict(
|
149 |
+
eval_loss=loss,
|
150 |
+
eval_accuracy=accuracy,
|
151 |
+
)
|
152 |
+
return rng_generator(), metrics
|
153 |
+
|
154 |
+
train_state_shapes = jax.eval_shape(init_fn, next_rng())
|
155 |
+
train_state_partition = match_partition_rules(
|
156 |
+
LLaMAConfig.get_partition_rules(), train_state_shapes
|
157 |
+
)
|
158 |
+
|
159 |
+
shard_fns, gather_fns = make_shard_and_gather_fns(
|
160 |
+
train_state_partition, train_state_shapes
|
161 |
+
)
|
162 |
+
checkpointer = StreamingCheckpointer(
|
163 |
+
FLAGS.checkpointer, logger.output_dir,
|
164 |
+
enable=jax.process_index() == 0,
|
165 |
+
)
|
166 |
+
|
167 |
+
sharded_init_fn = pjit(
|
168 |
+
init_fn,
|
169 |
+
in_shardings=PS(),
|
170 |
+
out_shardings=train_state_partition
|
171 |
+
)
|
172 |
+
|
173 |
+
sharded_create_trainstate_from_params = pjit(
|
174 |
+
create_trainstate_from_params,
|
175 |
+
in_shardings=(train_state_partition.params, ),
|
176 |
+
out_shardings=train_state_partition,
|
177 |
+
donate_argnums=(0, ),
|
178 |
+
)
|
179 |
+
|
180 |
+
sharded_train_step = pjit(
|
181 |
+
train_step,
|
182 |
+
in_shardings=(train_state_partition, PS(), PS()),
|
183 |
+
out_shardings=(train_state_partition, PS(), PS()),
|
184 |
+
donate_argnums=(0, 1),
|
185 |
+
)
|
186 |
+
|
187 |
+
sharded_eval_step = pjit(
|
188 |
+
eval_step,
|
189 |
+
in_shardings=(train_state_partition, PS(), PS()),
|
190 |
+
out_shardings=(PS(), PS()),
|
191 |
+
donate_argnums=(1,),
|
192 |
+
)
|
193 |
+
|
194 |
+
def save_checkpoint(train_state, milestone=False):
|
195 |
+
step = int(jax.device_get(train_state.step))
|
196 |
+
metadata = dict(
|
197 |
+
step=step,
|
198 |
+
variant=variant,
|
199 |
+
flags=flags_config_dict,
|
200 |
+
llama_config=llama_config.to_dict(),
|
201 |
+
)
|
202 |
+
checkpointer.save_all(
|
203 |
+
train_state=train_state,
|
204 |
+
gather_fns=gather_fns,
|
205 |
+
metadata=metadata,
|
206 |
+
dataset=dataset.get_state_dict(),
|
207 |
+
milestone=milestone,
|
208 |
+
)
|
209 |
+
|
210 |
+
mesh = LLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
|
211 |
+
with mesh:
|
212 |
+
train_state, restored_params = None, None
|
213 |
+
if FLAGS.load_checkpoint != '':
|
214 |
+
train_state, restored_params = checkpointer.load_trainstate_checkpoint(
|
215 |
+
FLAGS.load_checkpoint, train_state_shapes, shard_fns
|
216 |
+
)
|
217 |
+
|
218 |
+
if train_state is None and restored_params is None:
|
219 |
+
# Initialize from scratch
|
220 |
+
train_state = sharded_init_fn(next_rng())
|
221 |
+
elif train_state is None and restored_params is not None:
|
222 |
+
# Restore from params but initialize train_state
|
223 |
+
train_state = sharded_create_trainstate_from_params(restored_params)
|
224 |
+
del restored_params
|
225 |
+
|
226 |
+
start_step = int(jax.device_get(train_state.step))
|
227 |
+
|
228 |
+
if FLAGS.save_model_freq > 0:
|
229 |
+
save_checkpoint(train_state)
|
230 |
+
|
231 |
+
sharded_rng = next_rng()
|
232 |
+
|
233 |
+
step_counter = trange(start_step, FLAGS.total_steps, ncols=0)
|
234 |
+
|
235 |
+
for step, (batch, dataset_metrics) in zip(step_counter, dataset):
|
236 |
+
train_state, sharded_rng, metrics = sharded_train_step(
|
237 |
+
train_state, sharded_rng, batch
|
238 |
+
)
|
239 |
+
|
240 |
+
if FLAGS.eval_freq > 0 and (step + 1) % FLAGS.eval_freq == 0:
|
241 |
+
eval_metric_list = []
|
242 |
+
eval_iterator = iter(eval_dataset)
|
243 |
+
for eval_batch, _ in eval_iterator:
|
244 |
+
sharded_rng, eval_metrics = sharded_eval_step(
|
245 |
+
train_state, sharded_rng, eval_batch
|
246 |
+
)
|
247 |
+
eval_metric_list.append(eval_metrics)
|
248 |
+
metrics.update(average_metrics(eval_metric_list))
|
249 |
+
|
250 |
+
if FLAGS.log_freq > 0 and (step + 1) % FLAGS.log_freq == 0:
|
251 |
+
log_metrics = {"step": step + 1}
|
252 |
+
log_metrics.update(metrics)
|
253 |
+
log_metrics.update(dataset_metrics)
|
254 |
+
log_metrics = jax.device_get(log_metrics)
|
255 |
+
logger.log(log_metrics)
|
256 |
+
tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")
|
257 |
+
|
258 |
+
if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0:
|
259 |
+
save_checkpoint(train_state, milestone=True)
|
260 |
+
elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0:
|
261 |
+
save_checkpoint(train_state)
|
262 |
+
|
263 |
+
if FLAGS.save_model_freq > 0:
|
264 |
+
save_checkpoint(train_state)
|
265 |
+
|
266 |
+
|
267 |
+
if __name__ == "__main__":
|
268 |
+
mlxu.run(main)
|
EasyLM/models/roberta/__init__.py
ADDED
File without changes
|
EasyLM/models/roberta/roberta_model.py
ADDED
@@ -0,0 +1,1694 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
|
3 |
+
# Modifications copyright 2022 Xinyang Geng
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
from typing import Callable, Optional, Tuple
|
17 |
+
from collections import OrderedDict
|
18 |
+
from typing import Mapping
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
|
22 |
+
import flax.linen as nn
|
23 |
+
import jax
|
24 |
+
import jax.numpy as jnp
|
25 |
+
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
26 |
+
from flax.linen import combine_masks, make_causal_mask
|
27 |
+
from flax.linen import partitioning as nn_partitioning
|
28 |
+
from flax.linen.attention import dot_product_attention_weights
|
29 |
+
from flax.traverse_util import flatten_dict, unflatten_dict
|
30 |
+
from jax import lax
|
31 |
+
from jax.sharding import PartitionSpec
|
32 |
+
|
33 |
+
from transformers.configuration_utils import PretrainedConfig
|
34 |
+
from transformers.modeling_flax_outputs import (
|
35 |
+
FlaxBaseModelOutputWithPastAndCrossAttentions,
|
36 |
+
FlaxBaseModelOutputWithPooling,
|
37 |
+
FlaxBaseModelOutputWithPoolingAndCrossAttentions,
|
38 |
+
FlaxCausalLMOutputWithCrossAttentions,
|
39 |
+
FlaxMaskedLMOutput,
|
40 |
+
FlaxMultipleChoiceModelOutput,
|
41 |
+
FlaxQuestionAnsweringModelOutput,
|
42 |
+
FlaxSequenceClassifierOutput,
|
43 |
+
FlaxTokenClassifierOutput,
|
44 |
+
)
|
45 |
+
from transformers.modeling_flax_utils import (
|
46 |
+
ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring,
|
47 |
+
overwrite_call_docstring
|
48 |
+
)
|
49 |
+
from transformers.utils import (
|
50 |
+
add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
51 |
+
)
|
52 |
+
from transformers import AutoTokenizer
|
53 |
+
|
54 |
+
from ml_collections import ConfigDict
|
55 |
+
from ml_collections.config_dict import config_dict
|
56 |
+
from mlxu import function_args_to_config, load_pickle
|
57 |
+
|
58 |
+
from EasyLM.jax_utils import with_sharding_constraint, get_jax_mesh
|
59 |
+
|
60 |
+
|
61 |
+
"""
|
62 |
+
The follow code is taken from
|
63 |
+
transformers/src/transformers/models/roberta/configuration_roberta.py
|
64 |
+
and modified to work with EasyLM.
|
65 |
+
"""
|
66 |
+
|
67 |
+
|
68 |
+
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
69 |
+
"roberta-base": "https://huggingface.co/roberta-base/resolve/main/config.json",
|
70 |
+
"roberta-large": "https://huggingface.co/roberta-large/resolve/main/config.json",
|
71 |
+
"roberta-large-mnli": "https://huggingface.co/roberta-large-mnli/resolve/main/config.json",
|
72 |
+
"distilroberta-base": "https://huggingface.co/distilroberta-base/resolve/main/config.json",
|
73 |
+
"roberta-base-openai-detector": "https://huggingface.co/roberta-base-openai-detector/resolve/main/config.json",
|
74 |
+
"roberta-large-openai-detector": "https://huggingface.co/roberta-large-openai-detector/resolve/main/config.json",
|
75 |
+
}
|
76 |
+
|
77 |
+
|
78 |
+
class RobertaConfig(PretrainedConfig):
|
79 |
+
r"""
|
80 |
+
This is the configuration class to store the configuration of a [`RobertaModel`] or a [`TFRobertaModel`]. It is
|
81 |
+
used to instantiate a RoBERTa model according to the specified arguments, defining the model architecture.
|
82 |
+
Instantiating a configuration with the defaults will yield a similar configuration to that of the RoBERTa
|
83 |
+
[roberta-base](https://huggingface.co/roberta-base) architecture.
|
84 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
85 |
+
documentation from [`PretrainedConfig`] for more information.
|
86 |
+
Args:
|
87 |
+
vocab_size (`int`, *optional*, defaults to 30522):
|
88 |
+
Vocabulary size of the RoBERTa model. Defines the number of different tokens that can be represented by the
|
89 |
+
`inputs_ids` passed when calling [`RobertaModel`] or [`TFRobertaModel`].
|
90 |
+
hidden_size (`int`, *optional*, defaults to 768):
|
91 |
+
Dimensionality of the encoder layers and the pooler layer.
|
92 |
+
num_hidden_layers (`int`, *optional*, defaults to 12):
|
93 |
+
Number of hidden layers in the Transformer encoder.
|
94 |
+
num_attention_heads (`int`, *optional*, defaults to 12):
|
95 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
96 |
+
intermediate_size (`int`, *optional*, defaults to 3072):
|
97 |
+
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
|
98 |
+
hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
|
99 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
100 |
+
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
101 |
+
hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
|
102 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
103 |
+
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
|
104 |
+
The dropout ratio for the attention probabilities.
|
105 |
+
max_position_embeddings (`int`, *optional*, defaults to 512):
|
106 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
107 |
+
just in case (e.g., 512 or 1024 or 2048).
|
108 |
+
type_vocab_size (`int`, *optional*, defaults to 2):
|
109 |
+
The vocabulary size of the `token_type_ids` passed when calling [`RobertaModel`] or [`TFRobertaModel`].
|
110 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
111 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
112 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
113 |
+
The epsilon used by the layer normalization layers.
|
114 |
+
position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
|
115 |
+
Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
|
116 |
+
positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
|
117 |
+
[Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
|
118 |
+
For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
|
119 |
+
with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
|
120 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
121 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
122 |
+
relevant if `config.is_decoder=True`.
|
123 |
+
classifier_dropout (`float`, *optional*):
|
124 |
+
The dropout ratio for the classification head.
|
125 |
+
Examples:
|
126 |
+
```python
|
127 |
+
>>> from transformers import RobertaConfig, RobertaModel
|
128 |
+
>>> # Initializing a RoBERTa configuration
|
129 |
+
>>> configuration = RobertaConfig()
|
130 |
+
>>> # Initializing a model (with random weights) from the configuration
|
131 |
+
>>> model = RobertaModel(configuration)
|
132 |
+
>>> # Accessing the model configuration
|
133 |
+
>>> configuration = model.config
|
134 |
+
```"""
|
135 |
+
model_type = "roberta"
|
136 |
+
|
137 |
+
def __init__(
|
138 |
+
self,
|
139 |
+
vocab_size=50265,
|
140 |
+
hidden_size=768,
|
141 |
+
num_hidden_layers=12,
|
142 |
+
num_attention_heads=12,
|
143 |
+
intermediate_size=3072,
|
144 |
+
hidden_act="gelu",
|
145 |
+
hidden_dropout_prob=0.1,
|
146 |
+
attention_probs_dropout_prob=0.1,
|
147 |
+
max_position_embeddings=514,
|
148 |
+
type_vocab_size=1,
|
149 |
+
initializer_range=0.02,
|
150 |
+
layer_norm_eps=1e-5,
|
151 |
+
pad_token_id=1,
|
152 |
+
bos_token_id=0,
|
153 |
+
eos_token_id=2,
|
154 |
+
position_embedding_type="absolute",
|
155 |
+
use_cache=True,
|
156 |
+
classifier_dropout=None,
|
157 |
+
**kwargs
|
158 |
+
):
|
159 |
+
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
160 |
+
|
161 |
+
self.vocab_size = vocab_size
|
162 |
+
self.hidden_size = hidden_size
|
163 |
+
self.num_hidden_layers = num_hidden_layers
|
164 |
+
self.num_attention_heads = num_attention_heads
|
165 |
+
self.hidden_act = hidden_act
|
166 |
+
self.intermediate_size = intermediate_size
|
167 |
+
self.hidden_dropout_prob = hidden_dropout_prob
|
168 |
+
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
169 |
+
self.max_position_embeddings = max_position_embeddings
|
170 |
+
self.type_vocab_size = type_vocab_size
|
171 |
+
self.initializer_range = initializer_range
|
172 |
+
self.layer_norm_eps = layer_norm_eps
|
173 |
+
self.position_embedding_type = position_embedding_type
|
174 |
+
self.use_cache = use_cache
|
175 |
+
self.classifier_dropout = classifier_dropout
|
176 |
+
|
177 |
+
@classmethod
|
178 |
+
def get_default_config(cls, updates=None):
|
179 |
+
none_arg_types = dict(
|
180 |
+
classifier_dropout=float,
|
181 |
+
)
|
182 |
+
config = function_args_to_config(cls.__init__, none_arg_types=none_arg_types)
|
183 |
+
config.tie_word_embeddings = True
|
184 |
+
|
185 |
+
if updates is not None:
|
186 |
+
config.update(ConfigDict(updates).copy_and_resolve_references())
|
187 |
+
|
188 |
+
return config
|
189 |
+
|
190 |
+
@staticmethod
|
191 |
+
def get_jax_mesh(axis_dims):
|
192 |
+
return get_jax_mesh(axis_dims, ('dp', 'fsdp', 'mp'))
|
193 |
+
|
194 |
+
@staticmethod
|
195 |
+
def get_partition_rules():
|
196 |
+
""" Parition rules for Roberta model. """
|
197 |
+
return (
|
198 |
+
('embeddings/(position_embeddings|token_type_embeddings)/embedding', PartitionSpec()),
|
199 |
+
('embeddings/word_embeddings/embedding', PartitionSpec()),
|
200 |
+
('attention/self/(key|query|value)/kernel', PartitionSpec('fsdp', 'mp')),
|
201 |
+
('attention/self/(key|query|value)/bias', PartitionSpec()),
|
202 |
+
('attention/output/dense/kernel', PartitionSpec('mp', 'fsdp')),
|
203 |
+
('attention/output/dense/bias', PartitionSpec()),
|
204 |
+
('(LayerNorm|layer_norm)/(bias|scale)', PartitionSpec()),
|
205 |
+
('intermediate/dense/kernel', PartitionSpec('fsdp', 'mp')),
|
206 |
+
('intermediate/dense/bias', PartitionSpec('mp')),
|
207 |
+
('output/dense/kernel', PartitionSpec('mp', 'fsdp')),
|
208 |
+
('output/dense/bias', PartitionSpec()),
|
209 |
+
('lm_head/dense/kernel', PartitionSpec()),
|
210 |
+
('lm_head/dense/bias', PartitionSpec()),
|
211 |
+
('lm_head/decoder/kernel', PartitionSpec('fsdp', 'mp')),
|
212 |
+
('lm_head/decoder/bias', PartitionSpec('mp')),
|
213 |
+
('.*', PartitionSpec()),
|
214 |
+
)
|
215 |
+
|
216 |
+
@staticmethod
|
217 |
+
def get_weight_decay_exclusions():
|
218 |
+
return ('bias', 'LayerNorm/scale', 'layer_norm/scale')
|
219 |
+
|
220 |
+
@staticmethod
|
221 |
+
def rng_keys():
|
222 |
+
return ('params', 'dropout')
|
223 |
+
|
224 |
+
@staticmethod
|
225 |
+
def get_tokenizer_config(updates=None):
|
226 |
+
config = ConfigDict()
|
227 |
+
config.name = 'roberta-base'
|
228 |
+
|
229 |
+
if updates is not None:
|
230 |
+
config.update(ConfigDict(updates).copy_and_resolve_references())
|
231 |
+
|
232 |
+
return config
|
233 |
+
|
234 |
+
@classmethod
|
235 |
+
def get_tokenizer(cls, config):
|
236 |
+
config = cls.get_tokenizer_config(config)
|
237 |
+
return AutoTokenizer.from_pretrained(
|
238 |
+
config.name,
|
239 |
+
)
|
240 |
+
|
241 |
+
@staticmethod
|
242 |
+
def load_pretrained(name):
|
243 |
+
with jax.default_device(jax.devices("cpu")[0]):
|
244 |
+
params = FlaxRobertaForMaskedLM.from_pretrained(name, _do_init=False)[1]
|
245 |
+
params = freeze({'params': params})
|
246 |
+
return params
|
247 |
+
|
248 |
+
@classmethod
|
249 |
+
def load_config(cls, path):
|
250 |
+
load_type, load_path = path.split('::', 1)
|
251 |
+
if load_type == 'pickle':
|
252 |
+
return cls.from_dict(load_pickle(load_path)['roberta_config'])
|
253 |
+
elif load_type == 'huggingface':
|
254 |
+
return cls.from_pretrained(load_path)
|
255 |
+
else:
|
256 |
+
raise ValueError(f'Unsupported load config type: {load_type}')
|
257 |
+
|
258 |
+
|
259 |
+
"""
|
260 |
+
The follow code is taken from
|
261 |
+
transformers/src/transformers/models/roberta/modeling_flax_roberta.py
|
262 |
+
and modified to work with EasyLM.
|
263 |
+
"""
|
264 |
+
|
265 |
+
|
266 |
+
logger = logging.get_logger(__name__)
|
267 |
+
|
268 |
+
_CHECKPOINT_FOR_DOC = "roberta-base"
|
269 |
+
_CONFIG_FOR_DOC = "RobertaConfig"
|
270 |
+
|
271 |
+
remat = nn_partitioning.remat
|
272 |
+
|
273 |
+
|
274 |
+
def create_position_ids_from_input_ids(input_ids, padding_idx):
|
275 |
+
"""
|
276 |
+
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
|
277 |
+
are ignored. This is modified from fairseq's `utils.make_positions`.
|
278 |
+
Args:
|
279 |
+
input_ids: jnp.ndarray
|
280 |
+
padding_idx: int
|
281 |
+
Returns: jnp.ndarray
|
282 |
+
"""
|
283 |
+
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
|
284 |
+
mask = (input_ids != padding_idx).astype("i4")
|
285 |
+
|
286 |
+
if mask.ndim > 2:
|
287 |
+
mask = mask.reshape((-1, mask.shape[-1]))
|
288 |
+
incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask
|
289 |
+
incremental_indices = incremental_indices.reshape(input_ids.shape)
|
290 |
+
else:
|
291 |
+
incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask
|
292 |
+
|
293 |
+
return incremental_indices.astype("i4") + padding_idx
|
294 |
+
|
295 |
+
|
296 |
+
ROBERTA_START_DOCSTRING = r"""
|
297 |
+
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
|
298 |
+
library implements for all its model (such as downloading, saving and converting weights from PyTorch models)
|
299 |
+
This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
|
300 |
+
subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
|
301 |
+
general usage and behavior.
|
302 |
+
Finally, this model supports inherent JAX features such as:
|
303 |
+
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
|
304 |
+
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
|
305 |
+
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
|
306 |
+
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
|
307 |
+
Parameters:
|
308 |
+
config ([`RobertaConfig`]): Model configuration class with all the parameters of the
|
309 |
+
model. Initializing with a config file does not load the weights associated with the model, only the
|
310 |
+
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
|
311 |
+
"""
|
312 |
+
|
313 |
+
ROBERTA_INPUTS_DOCSTRING = r"""
|
314 |
+
Args:
|
315 |
+
input_ids (`numpy.ndarray` of shape `({0})`):
|
316 |
+
Indices of input sequence tokens in the vocabulary.
|
317 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
318 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
319 |
+
[What are input IDs?](../glossary#input-ids)
|
320 |
+
attention_mask (`numpy.ndarray` of shape `({0})`, *optional*):
|
321 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
322 |
+
- 1 for tokens that are **not masked**,
|
323 |
+
- 0 for tokens that are **masked**.
|
324 |
+
[What are attention masks?](../glossary#attention-mask)
|
325 |
+
token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*):
|
326 |
+
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
|
327 |
+
1]`:
|
328 |
+
- 0 corresponds to a *sentence A* token,
|
329 |
+
- 1 corresponds to a *sentence B* token.
|
330 |
+
[What are token type IDs?](../glossary#token-type-ids)
|
331 |
+
position_ids (`numpy.ndarray` of shape `({0})`, *optional*):
|
332 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
333 |
+
config.max_position_embeddings - 1]`.
|
334 |
+
head_mask (`numpy.ndarray` of shape `({0})`, `optional):
|
335 |
+
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
|
336 |
+
- 1 indicates the head is **not masked**,
|
337 |
+
- 0 indicates the head is **masked**.
|
338 |
+
return_dict (`bool`, *optional*):
|
339 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
340 |
+
"""
|
341 |
+
|
342 |
+
|
343 |
+
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings with Bert->Roberta
|
344 |
+
class FlaxRobertaEmbeddings(nn.Module):
|
345 |
+
"""Construct the embeddings from word, position and token_type embeddings."""
|
346 |
+
|
347 |
+
config: RobertaConfig
|
348 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
349 |
+
|
350 |
+
def setup(self):
|
351 |
+
self.word_embeddings = nn.Embed(
|
352 |
+
self.config.vocab_size,
|
353 |
+
self.config.hidden_size,
|
354 |
+
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
355 |
+
dtype=self.dtype,
|
356 |
+
)
|
357 |
+
self.position_embeddings = nn.Embed(
|
358 |
+
self.config.max_position_embeddings,
|
359 |
+
self.config.hidden_size,
|
360 |
+
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
361 |
+
dtype=self.dtype,
|
362 |
+
)
|
363 |
+
self.token_type_embeddings = nn.Embed(
|
364 |
+
self.config.type_vocab_size,
|
365 |
+
self.config.hidden_size,
|
366 |
+
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
367 |
+
dtype=self.dtype,
|
368 |
+
)
|
369 |
+
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
370 |
+
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
371 |
+
|
372 |
+
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
|
373 |
+
# Embed
|
374 |
+
inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
|
375 |
+
position_embeds = self.position_embeddings(position_ids.astype("i4"))
|
376 |
+
token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))
|
377 |
+
|
378 |
+
# Sum all embeddings
|
379 |
+
hidden_states = inputs_embeds + token_type_embeddings + position_embeds
|
380 |
+
|
381 |
+
# Layer Norm
|
382 |
+
hidden_states = self.LayerNorm(hidden_states)
|
383 |
+
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
384 |
+
return hidden_states
|
385 |
+
|
386 |
+
|
387 |
+
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->Roberta
|
388 |
+
class FlaxRobertaSelfAttention(nn.Module):
|
389 |
+
config: RobertaConfig
|
390 |
+
causal: bool = False
|
391 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
392 |
+
|
393 |
+
def setup(self):
|
394 |
+
self.head_dim = self.config.hidden_size // self.config.num_attention_heads
|
395 |
+
if self.config.hidden_size % self.config.num_attention_heads != 0:
|
396 |
+
raise ValueError(
|
397 |
+
"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` "
|
398 |
+
" : {self.config.num_attention_heads}"
|
399 |
+
)
|
400 |
+
|
401 |
+
self.query = nn.Dense(
|
402 |
+
self.config.hidden_size,
|
403 |
+
dtype=self.dtype,
|
404 |
+
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
405 |
+
)
|
406 |
+
self.key = nn.Dense(
|
407 |
+
self.config.hidden_size,
|
408 |
+
dtype=self.dtype,
|
409 |
+
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
410 |
+
)
|
411 |
+
self.value = nn.Dense(
|
412 |
+
self.config.hidden_size,
|
413 |
+
dtype=self.dtype,
|
414 |
+
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
415 |
+
)
|
416 |
+
|
417 |
+
if self.causal:
|
418 |
+
self.causal_mask = make_causal_mask(
|
419 |
+
jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
|
420 |
+
)
|
421 |
+
|
422 |
+
def _split_heads(self, hidden_states):
|
423 |
+
return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim))
|
424 |
+
|
425 |
+
def _merge_heads(self, hidden_states):
|
426 |
+
return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,))
|
427 |
+
|
428 |
+
@nn.compact
|
429 |
+
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache
|
430 |
+
def _concatenate_to_cache(self, key, value, query, attention_mask):
|
431 |
+
"""
|
432 |
+
This function takes projected key, value states from a single input token and concatenates the states to cached
|
433 |
+
states from previous steps. This function is slighly adapted from the official Flax repository:
|
434 |
+
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
|
435 |
+
"""
|
436 |
+
# detect if we're initializing by absence of existing cache data.
|
437 |
+
is_initialized = self.has_variable("cache", "cached_key")
|
438 |
+
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
|
439 |
+
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
|
440 |
+
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
|
441 |
+
|
442 |
+
if is_initialized:
|
443 |
+
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
|
444 |
+
# update key, value caches with our new 1d spatial slices
|
445 |
+
cur_index = cache_index.value
|
446 |
+
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
|
447 |
+
key = lax.dynamic_update_slice(cached_key.value, key, indices)
|
448 |
+
value = lax.dynamic_update_slice(cached_value.value, value, indices)
|
449 |
+
cached_key.value = key
|
450 |
+
cached_value.value = value
|
451 |
+
num_updated_cache_vectors = query.shape[1]
|
452 |
+
cache_index.value = cache_index.value + num_updated_cache_vectors
|
453 |
+
# causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
|
454 |
+
pad_mask = jnp.broadcast_to(
|
455 |
+
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
|
456 |
+
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
|
457 |
+
)
|
458 |
+
attention_mask = combine_masks(pad_mask, attention_mask)
|
459 |
+
return key, value, attention_mask
|
460 |
+
|
461 |
+
def __call__(
|
462 |
+
self,
|
463 |
+
hidden_states,
|
464 |
+
attention_mask,
|
465 |
+
layer_head_mask,
|
466 |
+
key_value_states: Optional[jnp.array] = None,
|
467 |
+
init_cache: bool = False,
|
468 |
+
deterministic=True,
|
469 |
+
output_attentions: bool = False,
|
470 |
+
):
|
471 |
+
# if key_value_states are provided this layer is used as a cross-attention layer
|
472 |
+
# for the decoder
|
473 |
+
is_cross_attention = key_value_states is not None
|
474 |
+
batch_size = hidden_states.shape[0]
|
475 |
+
|
476 |
+
# get query proj
|
477 |
+
query_states = self.query(hidden_states)
|
478 |
+
# get key, value proj
|
479 |
+
if is_cross_attention:
|
480 |
+
# cross_attentions
|
481 |
+
key_states = self.key(key_value_states)
|
482 |
+
value_states = self.value(key_value_states)
|
483 |
+
else:
|
484 |
+
# self_attention
|
485 |
+
key_states = self.key(hidden_states)
|
486 |
+
value_states = self.value(hidden_states)
|
487 |
+
|
488 |
+
query_states = self._split_heads(query_states)
|
489 |
+
key_states = self._split_heads(key_states)
|
490 |
+
value_states = self._split_heads(value_states)
|
491 |
+
|
492 |
+
# handle cache prepare causal attention mask
|
493 |
+
if self.causal:
|
494 |
+
query_length, key_length = query_states.shape[1], key_states.shape[1]
|
495 |
+
if self.has_variable("cache", "cached_key"):
|
496 |
+
mask_shift = self.variables["cache"]["cache_index"]
|
497 |
+
max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
|
498 |
+
causal_mask = lax.dynamic_slice(
|
499 |
+
self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
|
500 |
+
)
|
501 |
+
else:
|
502 |
+
causal_mask = self.causal_mask[:, :, :query_length, :key_length]
|
503 |
+
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
|
504 |
+
|
505 |
+
# combine masks if needed
|
506 |
+
if attention_mask is not None and self.causal:
|
507 |
+
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
|
508 |
+
attention_mask = combine_masks(attention_mask, causal_mask)
|
509 |
+
elif self.causal:
|
510 |
+
attention_mask = causal_mask
|
511 |
+
elif attention_mask is not None:
|
512 |
+
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
513 |
+
|
514 |
+
# During fast autoregressive decoding, we feed one position at a time,
|
515 |
+
# and cache the keys and values step by step.
|
516 |
+
if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
|
517 |
+
key_states, value_states, attention_mask = self._concatenate_to_cache(
|
518 |
+
key_states, value_states, query_states, attention_mask
|
519 |
+
)
|
520 |
+
|
521 |
+
# Convert the boolean attention mask to an attention bias.
|
522 |
+
if attention_mask is not None:
|
523 |
+
# attention mask in the form of attention bias
|
524 |
+
attention_bias = lax.select(
|
525 |
+
attention_mask > 0,
|
526 |
+
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
|
527 |
+
jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
|
528 |
+
)
|
529 |
+
else:
|
530 |
+
attention_bias = None
|
531 |
+
|
532 |
+
dropout_rng = None
|
533 |
+
if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
|
534 |
+
dropout_rng = self.make_rng("dropout")
|
535 |
+
|
536 |
+
attn_weights = dot_product_attention_weights(
|
537 |
+
query_states,
|
538 |
+
key_states,
|
539 |
+
bias=attention_bias,
|
540 |
+
dropout_rng=dropout_rng,
|
541 |
+
dropout_rate=self.config.attention_probs_dropout_prob,
|
542 |
+
broadcast_dropout=True,
|
543 |
+
deterministic=deterministic,
|
544 |
+
dtype=self.dtype,
|
545 |
+
precision=None,
|
546 |
+
)
|
547 |
+
|
548 |
+
# Mask heads if we want to
|
549 |
+
if layer_head_mask is not None:
|
550 |
+
attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask)
|
551 |
+
|
552 |
+
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
|
553 |
+
attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
|
554 |
+
|
555 |
+
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
|
556 |
+
return outputs
|
557 |
+
|
558 |
+
|
559 |
+
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->Roberta
|
560 |
+
class FlaxRobertaSelfOutput(nn.Module):
|
561 |
+
config: RobertaConfig
|
562 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
563 |
+
|
564 |
+
def setup(self):
|
565 |
+
self.dense = nn.Dense(
|
566 |
+
self.config.hidden_size,
|
567 |
+
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
568 |
+
dtype=self.dtype,
|
569 |
+
)
|
570 |
+
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
571 |
+
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
572 |
+
|
573 |
+
def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
|
574 |
+
hidden_states = self.dense(hidden_states)
|
575 |
+
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
576 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
577 |
+
return hidden_states
|
578 |
+
|
579 |
+
|
580 |
+
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Roberta
|
581 |
+
class FlaxRobertaAttention(nn.Module):
|
582 |
+
config: RobertaConfig
|
583 |
+
causal: bool = False
|
584 |
+
dtype: jnp.dtype = jnp.float32
|
585 |
+
|
586 |
+
def setup(self):
|
587 |
+
self.self = FlaxRobertaSelfAttention(self.config, causal=self.causal, dtype=self.dtype)
|
588 |
+
self.output = FlaxRobertaSelfOutput(self.config, dtype=self.dtype)
|
589 |
+
|
590 |
+
def __call__(
|
591 |
+
self,
|
592 |
+
hidden_states,
|
593 |
+
attention_mask,
|
594 |
+
layer_head_mask,
|
595 |
+
key_value_states=None,
|
596 |
+
init_cache=False,
|
597 |
+
deterministic=True,
|
598 |
+
output_attentions: bool = False,
|
599 |
+
):
|
600 |
+
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
601 |
+
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
602 |
+
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
603 |
+
attn_outputs = self.self(
|
604 |
+
hidden_states,
|
605 |
+
attention_mask,
|
606 |
+
layer_head_mask=layer_head_mask,
|
607 |
+
key_value_states=key_value_states,
|
608 |
+
init_cache=init_cache,
|
609 |
+
deterministic=deterministic,
|
610 |
+
output_attentions=output_attentions,
|
611 |
+
)
|
612 |
+
attn_output = attn_outputs[0]
|
613 |
+
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
|
614 |
+
|
615 |
+
outputs = (hidden_states,)
|
616 |
+
|
617 |
+
if output_attentions:
|
618 |
+
outputs += (attn_outputs[1],)
|
619 |
+
|
620 |
+
return outputs
|
621 |
+
|
622 |
+
|
623 |
+
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta
|
624 |
+
class FlaxRobertaIntermediate(nn.Module):
|
625 |
+
config: RobertaConfig
|
626 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
627 |
+
|
628 |
+
def setup(self):
|
629 |
+
self.dense = nn.Dense(
|
630 |
+
self.config.intermediate_size,
|
631 |
+
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
632 |
+
dtype=self.dtype,
|
633 |
+
)
|
634 |
+
self.activation = ACT2FN[self.config.hidden_act]
|
635 |
+
|
636 |
+
def __call__(self, hidden_states):
|
637 |
+
hidden_states = self.dense(hidden_states)
|
638 |
+
hidden_states = self.activation(hidden_states)
|
639 |
+
return hidden_states
|
640 |
+
|
641 |
+
|
642 |
+
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->Roberta
|
643 |
+
class FlaxRobertaOutput(nn.Module):
|
644 |
+
config: RobertaConfig
|
645 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
646 |
+
|
647 |
+
def setup(self):
|
648 |
+
self.dense = nn.Dense(
|
649 |
+
self.config.hidden_size,
|
650 |
+
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
651 |
+
dtype=self.dtype,
|
652 |
+
)
|
653 |
+
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
654 |
+
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
655 |
+
|
656 |
+
def __call__(self, hidden_states, attention_output, deterministic: bool = True):
|
657 |
+
hidden_states = self.dense(hidden_states)
|
658 |
+
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
659 |
+
hidden_states = self.LayerNorm(hidden_states + attention_output)
|
660 |
+
return hidden_states
|
661 |
+
|
662 |
+
|
663 |
+
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer with Bert->Roberta
|
664 |
+
class FlaxRobertaLayer(nn.Module):
|
665 |
+
config: RobertaConfig
|
666 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
667 |
+
|
668 |
+
def setup(self):
|
669 |
+
self.attention = FlaxRobertaAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype)
|
670 |
+
self.intermediate = FlaxRobertaIntermediate(self.config, dtype=self.dtype)
|
671 |
+
self.output = FlaxRobertaOutput(self.config, dtype=self.dtype)
|
672 |
+
if self.config.add_cross_attention:
|
673 |
+
self.crossattention = FlaxRobertaAttention(self.config, causal=False, dtype=self.dtype)
|
674 |
+
|
675 |
+
def __call__(
|
676 |
+
self,
|
677 |
+
hidden_states,
|
678 |
+
attention_mask,
|
679 |
+
layer_head_mask,
|
680 |
+
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
681 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
682 |
+
init_cache: bool = False,
|
683 |
+
deterministic: bool = True,
|
684 |
+
output_attentions: bool = False,
|
685 |
+
):
|
686 |
+
# Self Attention
|
687 |
+
attention_outputs = self.attention(
|
688 |
+
hidden_states,
|
689 |
+
attention_mask,
|
690 |
+
layer_head_mask=layer_head_mask,
|
691 |
+
init_cache=init_cache,
|
692 |
+
deterministic=deterministic,
|
693 |
+
output_attentions=output_attentions,
|
694 |
+
)
|
695 |
+
attention_output = attention_outputs[0]
|
696 |
+
|
697 |
+
# Cross-Attention Block
|
698 |
+
if encoder_hidden_states is not None:
|
699 |
+
cross_attention_outputs = self.crossattention(
|
700 |
+
attention_output,
|
701 |
+
attention_mask=encoder_attention_mask,
|
702 |
+
layer_head_mask=layer_head_mask,
|
703 |
+
key_value_states=encoder_hidden_states,
|
704 |
+
deterministic=deterministic,
|
705 |
+
output_attentions=output_attentions,
|
706 |
+
)
|
707 |
+
attention_output = cross_attention_outputs[0]
|
708 |
+
|
709 |
+
hidden_states = self.intermediate(attention_output)
|
710 |
+
hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
|
711 |
+
|
712 |
+
outputs = (hidden_states,)
|
713 |
+
|
714 |
+
if output_attentions:
|
715 |
+
outputs += (attention_outputs[1],)
|
716 |
+
if encoder_hidden_states is not None:
|
717 |
+
outputs += (cross_attention_outputs[1],)
|
718 |
+
return outputs
|
719 |
+
|
720 |
+
|
721 |
+
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->Roberta
|
722 |
+
class FlaxRobertaLayerCollection(nn.Module):
|
723 |
+
config: RobertaConfig
|
724 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
725 |
+
gradient_checkpointing: bool = False
|
726 |
+
|
727 |
+
def setup(self):
|
728 |
+
if self.gradient_checkpointing:
|
729 |
+
FlaxRobertaCheckpointLayer = remat(FlaxRobertaLayer, static_argnums=(5, 6, 7))
|
730 |
+
self.layers = [
|
731 |
+
FlaxRobertaCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
|
732 |
+
for i in range(self.config.num_hidden_layers)
|
733 |
+
]
|
734 |
+
else:
|
735 |
+
self.layers = [
|
736 |
+
FlaxRobertaLayer(self.config, name=str(i), dtype=self.dtype)
|
737 |
+
for i in range(self.config.num_hidden_layers)
|
738 |
+
]
|
739 |
+
|
740 |
+
def __call__(
|
741 |
+
self,
|
742 |
+
hidden_states,
|
743 |
+
attention_mask,
|
744 |
+
head_mask,
|
745 |
+
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
746 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
747 |
+
init_cache: bool = False,
|
748 |
+
deterministic: bool = True,
|
749 |
+
output_attentions: bool = False,
|
750 |
+
output_hidden_states: bool = False,
|
751 |
+
return_dict: bool = True,
|
752 |
+
):
|
753 |
+
all_attentions = () if output_attentions else None
|
754 |
+
all_hidden_states = () if output_hidden_states else None
|
755 |
+
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
756 |
+
|
757 |
+
# Check if head_mask has a correct number of layers specified if desired
|
758 |
+
if head_mask is not None:
|
759 |
+
if head_mask.shape[0] != (len(self.layers)):
|
760 |
+
raise ValueError(
|
761 |
+
f"The head_mask should be specified for {len(self.layers)} layers, but it is for "
|
762 |
+
f" {head_mask.shape[0]}."
|
763 |
+
)
|
764 |
+
|
765 |
+
for i, layer in enumerate(self.layers):
|
766 |
+
if output_hidden_states:
|
767 |
+
all_hidden_states += (hidden_states,)
|
768 |
+
|
769 |
+
layer_outputs = layer(
|
770 |
+
hidden_states,
|
771 |
+
attention_mask,
|
772 |
+
head_mask[i] if head_mask is not None else None,
|
773 |
+
encoder_hidden_states,
|
774 |
+
encoder_attention_mask,
|
775 |
+
init_cache,
|
776 |
+
deterministic,
|
777 |
+
output_attentions,
|
778 |
+
)
|
779 |
+
|
780 |
+
hidden_states = layer_outputs[0]
|
781 |
+
|
782 |
+
if output_attentions:
|
783 |
+
all_attentions += (layer_outputs[1],)
|
784 |
+
|
785 |
+
if encoder_hidden_states is not None:
|
786 |
+
all_cross_attentions += (layer_outputs[2],)
|
787 |
+
|
788 |
+
if output_hidden_states:
|
789 |
+
all_hidden_states += (hidden_states,)
|
790 |
+
|
791 |
+
outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions)
|
792 |
+
|
793 |
+
if not return_dict:
|
794 |
+
return tuple(v for v in outputs if v is not None)
|
795 |
+
|
796 |
+
return FlaxBaseModelOutputWithPastAndCrossAttentions(
|
797 |
+
last_hidden_state=hidden_states,
|
798 |
+
hidden_states=all_hidden_states,
|
799 |
+
attentions=all_attentions,
|
800 |
+
cross_attentions=all_cross_attentions,
|
801 |
+
)
|
802 |
+
|
803 |
+
|
804 |
+
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->Roberta
|
805 |
+
class FlaxRobertaEncoder(nn.Module):
|
806 |
+
config: RobertaConfig
|
807 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
808 |
+
gradient_checkpointing: bool = False
|
809 |
+
|
810 |
+
def setup(self):
|
811 |
+
self.layer = FlaxRobertaLayerCollection(
|
812 |
+
self.config,
|
813 |
+
dtype=self.dtype,
|
814 |
+
gradient_checkpointing=self.gradient_checkpointing,
|
815 |
+
)
|
816 |
+
|
817 |
+
def __call__(
|
818 |
+
self,
|
819 |
+
hidden_states,
|
820 |
+
attention_mask,
|
821 |
+
head_mask,
|
822 |
+
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
823 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
824 |
+
init_cache: bool = False,
|
825 |
+
deterministic: bool = True,
|
826 |
+
output_attentions: bool = False,
|
827 |
+
output_hidden_states: bool = False,
|
828 |
+
return_dict: bool = True,
|
829 |
+
):
|
830 |
+
return self.layer(
|
831 |
+
hidden_states,
|
832 |
+
attention_mask,
|
833 |
+
head_mask=head_mask,
|
834 |
+
encoder_hidden_states=encoder_hidden_states,
|
835 |
+
encoder_attention_mask=encoder_attention_mask,
|
836 |
+
init_cache=init_cache,
|
837 |
+
deterministic=deterministic,
|
838 |
+
output_attentions=output_attentions,
|
839 |
+
output_hidden_states=output_hidden_states,
|
840 |
+
return_dict=return_dict,
|
841 |
+
)
|
842 |
+
|
843 |
+
|
844 |
+
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->Roberta
|
845 |
+
class FlaxRobertaPooler(nn.Module):
|
846 |
+
config: RobertaConfig
|
847 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
848 |
+
|
849 |
+
def setup(self):
|
850 |
+
self.dense = nn.Dense(
|
851 |
+
self.config.hidden_size,
|
852 |
+
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
853 |
+
dtype=self.dtype,
|
854 |
+
)
|
855 |
+
|
856 |
+
def __call__(self, hidden_states):
|
857 |
+
cls_hidden_state = hidden_states[:, 0]
|
858 |
+
cls_hidden_state = self.dense(cls_hidden_state)
|
859 |
+
return nn.tanh(cls_hidden_state)
|
860 |
+
|
861 |
+
|
862 |
+
class FlaxRobertaLMHead(nn.Module):
|
863 |
+
config: RobertaConfig
|
864 |
+
dtype: jnp.dtype = jnp.float32
|
865 |
+
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
|
866 |
+
|
867 |
+
def setup(self):
|
868 |
+
self.dense = nn.Dense(
|
869 |
+
self.config.hidden_size,
|
870 |
+
dtype=self.dtype,
|
871 |
+
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
872 |
+
)
|
873 |
+
self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
874 |
+
self.decoder = nn.Dense(
|
875 |
+
self.config.vocab_size,
|
876 |
+
dtype=self.dtype,
|
877 |
+
use_bias=False,
|
878 |
+
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
879 |
+
)
|
880 |
+
self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))
|
881 |
+
|
882 |
+
def __call__(self, hidden_states, shared_embedding=None):
|
883 |
+
hidden_states = self.dense(hidden_states)
|
884 |
+
hidden_states = ACT2FN["gelu"](hidden_states)
|
885 |
+
hidden_states = self.layer_norm(hidden_states)
|
886 |
+
|
887 |
+
if shared_embedding is not None:
|
888 |
+
hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
|
889 |
+
else:
|
890 |
+
hidden_states = self.decoder(hidden_states)
|
891 |
+
|
892 |
+
bias = jnp.asarray(self.bias, self.dtype)
|
893 |
+
hidden_states += bias
|
894 |
+
return hidden_states
|
895 |
+
|
896 |
+
|
897 |
+
class FlaxRobertaClassificationHead(nn.Module):
|
898 |
+
config: RobertaConfig
|
899 |
+
dtype: jnp.dtype = jnp.float32
|
900 |
+
|
901 |
+
def setup(self):
|
902 |
+
self.dense = nn.Dense(
|
903 |
+
self.config.hidden_size,
|
904 |
+
dtype=self.dtype,
|
905 |
+
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
906 |
+
)
|
907 |
+
classifier_dropout = (
|
908 |
+
self.config.classifier_dropout
|
909 |
+
if self.config.classifier_dropout is not None
|
910 |
+
else self.config.hidden_dropout_prob
|
911 |
+
)
|
912 |
+
self.dropout = nn.Dropout(rate=classifier_dropout)
|
913 |
+
self.out_proj = nn.Dense(
|
914 |
+
self.config.num_labels,
|
915 |
+
dtype=self.dtype,
|
916 |
+
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
917 |
+
)
|
918 |
+
|
919 |
+
def __call__(self, hidden_states, deterministic=True):
|
920 |
+
hidden_states = hidden_states[:, 0, :] # take <s> token (equiv. to [CLS])
|
921 |
+
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
922 |
+
hidden_states = self.dense(hidden_states)
|
923 |
+
hidden_states = nn.tanh(hidden_states)
|
924 |
+
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
925 |
+
hidden_states = self.out_proj(hidden_states)
|
926 |
+
return hidden_states
|
927 |
+
|
928 |
+
|
929 |
+
class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
|
930 |
+
"""
|
931 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
932 |
+
models.
|
933 |
+
"""
|
934 |
+
|
935 |
+
config_class = RobertaConfig
|
936 |
+
base_model_prefix = "roberta"
|
937 |
+
|
938 |
+
module_class: nn.Module = None
|
939 |
+
|
940 |
+
def __init__(
|
941 |
+
self,
|
942 |
+
config: RobertaConfig,
|
943 |
+
input_shape: Tuple = (1, 1),
|
944 |
+
seed: int = 0,
|
945 |
+
dtype: jnp.dtype = jnp.float32,
|
946 |
+
_do_init: bool = True,
|
947 |
+
gradient_checkpointing: bool = False,
|
948 |
+
**kwargs,
|
949 |
+
):
|
950 |
+
module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
|
951 |
+
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
952 |
+
|
953 |
+
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing
|
954 |
+
def enable_gradient_checkpointing(self):
|
955 |
+
self._module = self.module_class(
|
956 |
+
config=self.config,
|
957 |
+
dtype=self.dtype,
|
958 |
+
gradient_checkpointing=True,
|
959 |
+
)
|
960 |
+
|
961 |
+
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
962 |
+
# init input tensors
|
963 |
+
input_ids = jnp.zeros(input_shape, dtype="i4")
|
964 |
+
token_type_ids = jnp.ones_like(input_ids)
|
965 |
+
position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
|
966 |
+
attention_mask = jnp.ones_like(input_ids)
|
967 |
+
head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))
|
968 |
+
|
969 |
+
params_rng, dropout_rng = jax.random.split(rng)
|
970 |
+
rngs = {"params": params_rng, "dropout": dropout_rng}
|
971 |
+
|
972 |
+
if self.config.add_cross_attention:
|
973 |
+
encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
|
974 |
+
encoder_attention_mask = attention_mask
|
975 |
+
module_init_outputs = self.module.init(
|
976 |
+
rngs,
|
977 |
+
input_ids,
|
978 |
+
attention_mask,
|
979 |
+
token_type_ids,
|
980 |
+
position_ids,
|
981 |
+
head_mask,
|
982 |
+
encoder_hidden_states,
|
983 |
+
encoder_attention_mask,
|
984 |
+
return_dict=False,
|
985 |
+
)
|
986 |
+
else:
|
987 |
+
module_init_outputs = self.module.init(
|
988 |
+
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
|
989 |
+
)
|
990 |
+
|
991 |
+
random_params = module_init_outputs["params"]
|
992 |
+
|
993 |
+
if params is not None:
|
994 |
+
random_params = flatten_dict(unfreeze(random_params))
|
995 |
+
params = flatten_dict(unfreeze(params))
|
996 |
+
for missing_key in self._missing_keys:
|
997 |
+
params[missing_key] = random_params[missing_key]
|
998 |
+
self._missing_keys = set()
|
999 |
+
return freeze(unflatten_dict(params))
|
1000 |
+
else:
|
1001 |
+
return random_params
|
1002 |
+
|
1003 |
+
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache
|
1004 |
+
def init_cache(self, batch_size, max_length):
|
1005 |
+
r"""
|
1006 |
+
Args:
|
1007 |
+
batch_size (`int`):
|
1008 |
+
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
|
1009 |
+
max_length (`int`):
|
1010 |
+
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
|
1011 |
+
cache.
|
1012 |
+
"""
|
1013 |
+
# init input variables to retrieve cache
|
1014 |
+
input_ids = jnp.ones((batch_size, max_length), dtype="i4")
|
1015 |
+
attention_mask = jnp.ones_like(input_ids, dtype="i4")
|
1016 |
+
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
|
1017 |
+
|
1018 |
+
init_variables = self.module.init(
|
1019 |
+
jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
|
1020 |
+
)
|
1021 |
+
return unfreeze(init_variables["cache"])
|
1022 |
+
|
1023 |
+
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
1024 |
+
def __call__(
|
1025 |
+
self,
|
1026 |
+
input_ids,
|
1027 |
+
attention_mask=None,
|
1028 |
+
token_type_ids=None,
|
1029 |
+
position_ids=None,
|
1030 |
+
head_mask=None,
|
1031 |
+
encoder_hidden_states=None,
|
1032 |
+
encoder_attention_mask=None,
|
1033 |
+
params: dict = None,
|
1034 |
+
dropout_rng: jax.random.PRNGKey = None,
|
1035 |
+
train: bool = False,
|
1036 |
+
output_attentions: Optional[bool] = None,
|
1037 |
+
output_hidden_states: Optional[bool] = None,
|
1038 |
+
return_dict: Optional[bool] = None,
|
1039 |
+
past_key_values: dict = None,
|
1040 |
+
):
|
1041 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1042 |
+
output_hidden_states = (
|
1043 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1044 |
+
)
|
1045 |
+
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
1046 |
+
|
1047 |
+
# init input tensors if not passed
|
1048 |
+
if token_type_ids is None:
|
1049 |
+
token_type_ids = jnp.zeros_like(input_ids)
|
1050 |
+
|
1051 |
+
if position_ids is None:
|
1052 |
+
position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
|
1053 |
+
|
1054 |
+
if attention_mask is None:
|
1055 |
+
attention_mask = jnp.ones_like(input_ids)
|
1056 |
+
|
1057 |
+
if head_mask is None:
|
1058 |
+
head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))
|
1059 |
+
|
1060 |
+
# Handle any PRNG if needed
|
1061 |
+
rngs = {}
|
1062 |
+
if dropout_rng is not None:
|
1063 |
+
rngs["dropout"] = dropout_rng
|
1064 |
+
|
1065 |
+
inputs = {"params": params or self.params}
|
1066 |
+
|
1067 |
+
if self.config.add_cross_attention:
|
1068 |
+
# if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed
|
1069 |
+
# down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be
|
1070 |
+
# changed by FlaxRobertaAttention module
|
1071 |
+
if past_key_values:
|
1072 |
+
inputs["cache"] = past_key_values
|
1073 |
+
mutable = ["cache"]
|
1074 |
+
else:
|
1075 |
+
mutable = False
|
1076 |
+
|
1077 |
+
outputs = self.module.apply(
|
1078 |
+
inputs,
|
1079 |
+
jnp.array(input_ids, dtype="i4"),
|
1080 |
+
jnp.array(attention_mask, dtype="i4"),
|
1081 |
+
token_type_ids=jnp.array(token_type_ids, dtype="i4"),
|
1082 |
+
position_ids=jnp.array(position_ids, dtype="i4"),
|
1083 |
+
head_mask=jnp.array(head_mask, dtype="i4"),
|
1084 |
+
encoder_hidden_states=encoder_hidden_states,
|
1085 |
+
encoder_attention_mask=encoder_attention_mask,
|
1086 |
+
deterministic=not train,
|
1087 |
+
output_attentions=output_attentions,
|
1088 |
+
output_hidden_states=output_hidden_states,
|
1089 |
+
return_dict=return_dict,
|
1090 |
+
rngs=rngs,
|
1091 |
+
mutable=mutable,
|
1092 |
+
)
|
1093 |
+
|
1094 |
+
# add updated cache to model output
|
1095 |
+
if past_key_values is not None and return_dict:
|
1096 |
+
outputs, past_key_values = outputs
|
1097 |
+
outputs["past_key_values"] = unfreeze(past_key_values["cache"])
|
1098 |
+
return outputs
|
1099 |
+
elif past_key_values is not None and not return_dict:
|
1100 |
+
outputs, past_key_values = outputs
|
1101 |
+
outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
|
1102 |
+
|
1103 |
+
else:
|
1104 |
+
outputs = self.module.apply(
|
1105 |
+
inputs,
|
1106 |
+
jnp.array(input_ids, dtype="i4"),
|
1107 |
+
jnp.array(attention_mask, dtype="i4"),
|
1108 |
+
token_type_ids=jnp.array(token_type_ids, dtype="i4"),
|
1109 |
+
position_ids=jnp.array(position_ids, dtype="i4"),
|
1110 |
+
head_mask=jnp.array(head_mask, dtype="i4"),
|
1111 |
+
deterministic=not train,
|
1112 |
+
output_attentions=output_attentions,
|
1113 |
+
output_hidden_states=output_hidden_states,
|
1114 |
+
return_dict=return_dict,
|
1115 |
+
rngs=rngs,
|
1116 |
+
)
|
1117 |
+
|
1118 |
+
return outputs
|
1119 |
+
|
1120 |
+
|
1121 |
+
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta
|
1122 |
+
class FlaxRobertaModule(nn.Module):
|
1123 |
+
config: RobertaConfig
|
1124 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
1125 |
+
add_pooling_layer: bool = True
|
1126 |
+
gradient_checkpointing: bool = False
|
1127 |
+
|
1128 |
+
def setup(self):
|
1129 |
+
self.embeddings = FlaxRobertaEmbeddings(self.config, dtype=self.dtype)
|
1130 |
+
self.encoder = FlaxRobertaEncoder(
|
1131 |
+
self.config,
|
1132 |
+
dtype=self.dtype,
|
1133 |
+
gradient_checkpointing=self.gradient_checkpointing,
|
1134 |
+
)
|
1135 |
+
self.pooler = FlaxRobertaPooler(self.config, dtype=self.dtype)
|
1136 |
+
|
1137 |
+
def __call__(
|
1138 |
+
self,
|
1139 |
+
input_ids,
|
1140 |
+
attention_mask,
|
1141 |
+
token_type_ids: Optional[jnp.ndarray] = None,
|
1142 |
+
position_ids: Optional[jnp.ndarray] = None,
|
1143 |
+
head_mask: Optional[jnp.ndarray] = None,
|
1144 |
+
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
1145 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
1146 |
+
init_cache: bool = False,
|
1147 |
+
deterministic: bool = True,
|
1148 |
+
output_attentions: bool = False,
|
1149 |
+
output_hidden_states: bool = False,
|
1150 |
+
return_dict: bool = True,
|
1151 |
+
):
|
1152 |
+
# make sure `token_type_ids` is correctly initialized when not passed
|
1153 |
+
if token_type_ids is None:
|
1154 |
+
token_type_ids = jnp.zeros_like(input_ids)
|
1155 |
+
|
1156 |
+
# make sure `position_ids` is correctly initialized when not passed
|
1157 |
+
if position_ids is None:
|
1158 |
+
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
|
1159 |
+
|
1160 |
+
hidden_states = self.embeddings(
|
1161 |
+
input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
|
1162 |
+
)
|
1163 |
+
outputs = self.encoder(
|
1164 |
+
hidden_states,
|
1165 |
+
attention_mask,
|
1166 |
+
head_mask=head_mask,
|
1167 |
+
deterministic=deterministic,
|
1168 |
+
encoder_hidden_states=encoder_hidden_states,
|
1169 |
+
encoder_attention_mask=encoder_attention_mask,
|
1170 |
+
init_cache=init_cache,
|
1171 |
+
output_attentions=output_attentions,
|
1172 |
+
output_hidden_states=output_hidden_states,
|
1173 |
+
return_dict=return_dict,
|
1174 |
+
)
|
1175 |
+
hidden_states = outputs[0]
|
1176 |
+
pooled = self.pooler(hidden_states) if self.add_pooling_layer else None
|
1177 |
+
|
1178 |
+
if not return_dict:
|
1179 |
+
# if pooled is None, don't return it
|
1180 |
+
if pooled is None:
|
1181 |
+
return (hidden_states,) + outputs[1:]
|
1182 |
+
return (hidden_states, pooled) + outputs[1:]
|
1183 |
+
|
1184 |
+
return FlaxBaseModelOutputWithPoolingAndCrossAttentions(
|
1185 |
+
last_hidden_state=hidden_states,
|
1186 |
+
pooler_output=pooled,
|
1187 |
+
hidden_states=outputs.hidden_states,
|
1188 |
+
attentions=outputs.attentions,
|
1189 |
+
cross_attentions=outputs.cross_attentions,
|
1190 |
+
)
|
1191 |
+
|
1192 |
+
|
1193 |
+
@add_start_docstrings(
|
1194 |
+
"The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
|
1195 |
+
ROBERTA_START_DOCSTRING,
|
1196 |
+
)
|
1197 |
+
class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
|
1198 |
+
module_class = FlaxRobertaModule
|
1199 |
+
|
1200 |
+
|
1201 |
+
append_call_sample_docstring(FlaxRobertaModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC)
|
1202 |
+
|
1203 |
+
|
1204 |
+
class FlaxRobertaForMaskedLMModule(nn.Module):
|
1205 |
+
config: RobertaConfig
|
1206 |
+
dtype: jnp.dtype = jnp.float32
|
1207 |
+
gradient_checkpointing: bool = False
|
1208 |
+
|
1209 |
+
def setup(self):
|
1210 |
+
self.roberta = FlaxRobertaModule(
|
1211 |
+
config=self.config,
|
1212 |
+
add_pooling_layer=False,
|
1213 |
+
dtype=self.dtype,
|
1214 |
+
gradient_checkpointing=self.gradient_checkpointing,
|
1215 |
+
)
|
1216 |
+
self.lm_head = FlaxRobertaLMHead(config=self.config, dtype=self.dtype)
|
1217 |
+
|
1218 |
+
def __call__(
|
1219 |
+
self,
|
1220 |
+
input_ids,
|
1221 |
+
attention_mask,
|
1222 |
+
token_type_ids,
|
1223 |
+
position_ids,
|
1224 |
+
head_mask,
|
1225 |
+
deterministic: bool = True,
|
1226 |
+
output_attentions: bool = False,
|
1227 |
+
output_hidden_states: bool = False,
|
1228 |
+
return_dict: bool = True,
|
1229 |
+
):
|
1230 |
+
# Model
|
1231 |
+
outputs = self.roberta(
|
1232 |
+
input_ids,
|
1233 |
+
attention_mask,
|
1234 |
+
token_type_ids,
|
1235 |
+
position_ids,
|
1236 |
+
head_mask,
|
1237 |
+
deterministic=deterministic,
|
1238 |
+
output_attentions=output_attentions,
|
1239 |
+
output_hidden_states=output_hidden_states,
|
1240 |
+
return_dict=return_dict,
|
1241 |
+
)
|
1242 |
+
|
1243 |
+
hidden_states = outputs[0]
|
1244 |
+
if self.config.tie_word_embeddings:
|
1245 |
+
shared_embedding = self.roberta.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
|
1246 |
+
else:
|
1247 |
+
shared_embedding = None
|
1248 |
+
|
1249 |
+
# Compute the prediction scores
|
1250 |
+
logits = self.lm_head(hidden_states, shared_embedding=shared_embedding)
|
1251 |
+
|
1252 |
+
if not return_dict:
|
1253 |
+
return (logits,) + outputs[1:]
|
1254 |
+
|
1255 |
+
return FlaxMaskedLMOutput(
|
1256 |
+
logits=logits,
|
1257 |
+
hidden_states=outputs.hidden_states,
|
1258 |
+
attentions=outputs.attentions,
|
1259 |
+
)
|
1260 |
+
|
1261 |
+
|
1262 |
+
@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top.""", ROBERTA_START_DOCSTRING)
|
1263 |
+
class FlaxRobertaForMaskedLM(FlaxRobertaPreTrainedModel):
|
1264 |
+
module_class = FlaxRobertaForMaskedLMModule
|
1265 |
+
|
1266 |
+
|
1267 |
+
append_call_sample_docstring(
|
1268 |
+
FlaxRobertaForMaskedLM,
|
1269 |
+
_CHECKPOINT_FOR_DOC,
|
1270 |
+
FlaxBaseModelOutputWithPooling,
|
1271 |
+
_CONFIG_FOR_DOC,
|
1272 |
+
mask="<mask>",
|
1273 |
+
)
|
1274 |
+
|
1275 |
+
|
1276 |
+
class FlaxRobertaForSequenceClassificationModule(nn.Module):
|
1277 |
+
config: RobertaConfig
|
1278 |
+
dtype: jnp.dtype = jnp.float32
|
1279 |
+
gradient_checkpointing: bool = False
|
1280 |
+
|
1281 |
+
def setup(self):
|
1282 |
+
self.roberta = FlaxRobertaModule(
|
1283 |
+
config=self.config,
|
1284 |
+
dtype=self.dtype,
|
1285 |
+
add_pooling_layer=False,
|
1286 |
+
gradient_checkpointing=self.gradient_checkpointing,
|
1287 |
+
)
|
1288 |
+
self.classifier = FlaxRobertaClassificationHead(config=self.config, dtype=self.dtype)
|
1289 |
+
|
1290 |
+
def __call__(
|
1291 |
+
self,
|
1292 |
+
input_ids,
|
1293 |
+
attention_mask,
|
1294 |
+
token_type_ids,
|
1295 |
+
position_ids,
|
1296 |
+
head_mask,
|
1297 |
+
deterministic: bool = True,
|
1298 |
+
output_attentions: bool = False,
|
1299 |
+
output_hidden_states: bool = False,
|
1300 |
+
return_dict: bool = True,
|
1301 |
+
):
|
1302 |
+
# Model
|
1303 |
+
outputs = self.roberta(
|
1304 |
+
input_ids,
|
1305 |
+
attention_mask,
|
1306 |
+
token_type_ids,
|
1307 |
+
position_ids,
|
1308 |
+
head_mask,
|
1309 |
+
deterministic=deterministic,
|
1310 |
+
output_attentions=output_attentions,
|
1311 |
+
output_hidden_states=output_hidden_states,
|
1312 |
+
return_dict=return_dict,
|
1313 |
+
)
|
1314 |
+
|
1315 |
+
sequence_output = outputs[0]
|
1316 |
+
logits = self.classifier(sequence_output, deterministic=deterministic)
|
1317 |
+
|
1318 |
+
if not return_dict:
|
1319 |
+
return (logits,) + outputs[1:]
|
1320 |
+
|
1321 |
+
return FlaxSequenceClassifierOutput(
|
1322 |
+
logits=logits,
|
1323 |
+
hidden_states=outputs.hidden_states,
|
1324 |
+
attentions=outputs.attentions,
|
1325 |
+
)
|
1326 |
+
|
1327 |
+
|
1328 |
+
@add_start_docstrings(
|
1329 |
+
"""
|
1330 |
+
Roberta Model transformer with a sequence classification/regression head on top (a linear layer on top of the
|
1331 |
+
pooled output) e.g. for GLUE tasks.
|
1332 |
+
""",
|
1333 |
+
ROBERTA_START_DOCSTRING,
|
1334 |
+
)
|
1335 |
+
class FlaxRobertaForSequenceClassification(FlaxRobertaPreTrainedModel):
|
1336 |
+
module_class = FlaxRobertaForSequenceClassificationModule
|
1337 |
+
|
1338 |
+
|
1339 |
+
append_call_sample_docstring(
|
1340 |
+
FlaxRobertaForSequenceClassification,
|
1341 |
+
_CHECKPOINT_FOR_DOC,
|
1342 |
+
FlaxSequenceClassifierOutput,
|
1343 |
+
_CONFIG_FOR_DOC,
|
1344 |
+
)
|
1345 |
+
|
1346 |
+
|
1347 |
+
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForMultipleChoiceModule with Bert->Roberta, with self.bert->self.roberta
|
1348 |
+
class FlaxRobertaForMultipleChoiceModule(nn.Module):
|
1349 |
+
config: RobertaConfig
|
1350 |
+
dtype: jnp.dtype = jnp.float32
|
1351 |
+
gradient_checkpointing: bool = False
|
1352 |
+
|
1353 |
+
def setup(self):
|
1354 |
+
self.roberta = FlaxRobertaModule(
|
1355 |
+
config=self.config,
|
1356 |
+
dtype=self.dtype,
|
1357 |
+
gradient_checkpointing=self.gradient_checkpointing,
|
1358 |
+
)
|
1359 |
+
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
1360 |
+
self.classifier = nn.Dense(1, dtype=self.dtype)
|
1361 |
+
|
1362 |
+
def __call__(
|
1363 |
+
self,
|
1364 |
+
input_ids,
|
1365 |
+
attention_mask,
|
1366 |
+
token_type_ids,
|
1367 |
+
position_ids,
|
1368 |
+
head_mask,
|
1369 |
+
deterministic: bool = True,
|
1370 |
+
output_attentions: bool = False,
|
1371 |
+
output_hidden_states: bool = False,
|
1372 |
+
return_dict: bool = True,
|
1373 |
+
):
|
1374 |
+
num_choices = input_ids.shape[1]
|
1375 |
+
input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None
|
1376 |
+
attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None
|
1377 |
+
token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None
|
1378 |
+
position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None
|
1379 |
+
|
1380 |
+
# Model
|
1381 |
+
outputs = self.roberta(
|
1382 |
+
input_ids,
|
1383 |
+
attention_mask,
|
1384 |
+
token_type_ids,
|
1385 |
+
position_ids,
|
1386 |
+
head_mask,
|
1387 |
+
deterministic=deterministic,
|
1388 |
+
output_attentions=output_attentions,
|
1389 |
+
output_hidden_states=output_hidden_states,
|
1390 |
+
return_dict=return_dict,
|
1391 |
+
)
|
1392 |
+
|
1393 |
+
pooled_output = outputs[1]
|
1394 |
+
pooled_output = self.dropout(pooled_output, deterministic=deterministic)
|
1395 |
+
logits = self.classifier(pooled_output)
|
1396 |
+
|
1397 |
+
reshaped_logits = logits.reshape(-1, num_choices)
|
1398 |
+
|
1399 |
+
if not return_dict:
|
1400 |
+
return (reshaped_logits,) + outputs[2:]
|
1401 |
+
|
1402 |
+
return FlaxMultipleChoiceModelOutput(
|
1403 |
+
logits=reshaped_logits,
|
1404 |
+
hidden_states=outputs.hidden_states,
|
1405 |
+
attentions=outputs.attentions,
|
1406 |
+
)
|
1407 |
+
|
1408 |
+
|
1409 |
+
@add_start_docstrings(
|
1410 |
+
"""
|
1411 |
+
Roberta Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
|
1412 |
+
softmax) e.g. for RocStories/SWAG tasks.
|
1413 |
+
""",
|
1414 |
+
ROBERTA_START_DOCSTRING,
|
1415 |
+
)
|
1416 |
+
class FlaxRobertaForMultipleChoice(FlaxRobertaPreTrainedModel):
|
1417 |
+
module_class = FlaxRobertaForMultipleChoiceModule
|
1418 |
+
|
1419 |
+
|
1420 |
+
overwrite_call_docstring(
|
1421 |
+
FlaxRobertaForMultipleChoice, ROBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
|
1422 |
+
)
|
1423 |
+
append_call_sample_docstring(
|
1424 |
+
FlaxRobertaForMultipleChoice,
|
1425 |
+
_CHECKPOINT_FOR_DOC,
|
1426 |
+
FlaxMultipleChoiceModelOutput,
|
1427 |
+
_CONFIG_FOR_DOC,
|
1428 |
+
)
|
1429 |
+
|
1430 |
+
|
1431 |
+
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForTokenClassificationModule with Bert->Roberta, with self.bert->self.roberta
|
1432 |
+
class FlaxRobertaForTokenClassificationModule(nn.Module):
|
1433 |
+
config: RobertaConfig
|
1434 |
+
dtype: jnp.dtype = jnp.float32
|
1435 |
+
gradient_checkpointing: bool = False
|
1436 |
+
|
1437 |
+
def setup(self):
|
1438 |
+
self.roberta = FlaxRobertaModule(
|
1439 |
+
config=self.config,
|
1440 |
+
dtype=self.dtype,
|
1441 |
+
add_pooling_layer=False,
|
1442 |
+
gradient_checkpointing=self.gradient_checkpointing,
|
1443 |
+
)
|
1444 |
+
classifier_dropout = (
|
1445 |
+
self.config.classifier_dropout
|
1446 |
+
if self.config.classifier_dropout is not None
|
1447 |
+
else self.config.hidden_dropout_prob
|
1448 |
+
)
|
1449 |
+
self.dropout = nn.Dropout(rate=classifier_dropout)
|
1450 |
+
self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
|
1451 |
+
|
1452 |
+
def __call__(
|
1453 |
+
self,
|
1454 |
+
input_ids,
|
1455 |
+
attention_mask,
|
1456 |
+
token_type_ids,
|
1457 |
+
position_ids,
|
1458 |
+
head_mask,
|
1459 |
+
deterministic: bool = True,
|
1460 |
+
output_attentions: bool = False,
|
1461 |
+
output_hidden_states: bool = False,
|
1462 |
+
return_dict: bool = True,
|
1463 |
+
):
|
1464 |
+
# Model
|
1465 |
+
outputs = self.roberta(
|
1466 |
+
input_ids,
|
1467 |
+
attention_mask,
|
1468 |
+
token_type_ids,
|
1469 |
+
position_ids,
|
1470 |
+
head_mask,
|
1471 |
+
deterministic=deterministic,
|
1472 |
+
output_attentions=output_attentions,
|
1473 |
+
output_hidden_states=output_hidden_states,
|
1474 |
+
return_dict=return_dict,
|
1475 |
+
)
|
1476 |
+
|
1477 |
+
hidden_states = outputs[0]
|
1478 |
+
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
1479 |
+
logits = self.classifier(hidden_states)
|
1480 |
+
|
1481 |
+
if not return_dict:
|
1482 |
+
return (logits,) + outputs[1:]
|
1483 |
+
|
1484 |
+
return FlaxTokenClassifierOutput(
|
1485 |
+
logits=logits,
|
1486 |
+
hidden_states=outputs.hidden_states,
|
1487 |
+
attentions=outputs.attentions,
|
1488 |
+
)
|
1489 |
+
|
1490 |
+
|
1491 |
+
@add_start_docstrings(
|
1492 |
+
"""
|
1493 |
+
Roberta Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
|
1494 |
+
Named-Entity-Recognition (NER) tasks.
|
1495 |
+
""",
|
1496 |
+
ROBERTA_START_DOCSTRING,
|
1497 |
+
)
|
1498 |
+
class FlaxRobertaForTokenClassification(FlaxRobertaPreTrainedModel):
|
1499 |
+
module_class = FlaxRobertaForTokenClassificationModule
|
1500 |
+
|
1501 |
+
|
1502 |
+
append_call_sample_docstring(
|
1503 |
+
FlaxRobertaForTokenClassification,
|
1504 |
+
_CHECKPOINT_FOR_DOC,
|
1505 |
+
FlaxTokenClassifierOutput,
|
1506 |
+
_CONFIG_FOR_DOC,
|
1507 |
+
)
|
1508 |
+
|
1509 |
+
|
1510 |
+
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForQuestionAnsweringModule with Bert->Roberta, with self.bert->self.roberta
|
1511 |
+
class FlaxRobertaForQuestionAnsweringModule(nn.Module):
|
1512 |
+
config: RobertaConfig
|
1513 |
+
dtype: jnp.dtype = jnp.float32
|
1514 |
+
gradient_checkpointing: bool = False
|
1515 |
+
|
1516 |
+
def setup(self):
|
1517 |
+
self.roberta = FlaxRobertaModule(
|
1518 |
+
config=self.config,
|
1519 |
+
dtype=self.dtype,
|
1520 |
+
add_pooling_layer=False,
|
1521 |
+
gradient_checkpointing=self.gradient_checkpointing,
|
1522 |
+
)
|
1523 |
+
self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
|
1524 |
+
|
1525 |
+
def __call__(
|
1526 |
+
self,
|
1527 |
+
input_ids,
|
1528 |
+
attention_mask,
|
1529 |
+
token_type_ids,
|
1530 |
+
position_ids,
|
1531 |
+
head_mask,
|
1532 |
+
deterministic: bool = True,
|
1533 |
+
output_attentions: bool = False,
|
1534 |
+
output_hidden_states: bool = False,
|
1535 |
+
return_dict: bool = True,
|
1536 |
+
):
|
1537 |
+
# Model
|
1538 |
+
outputs = self.roberta(
|
1539 |
+
input_ids,
|
1540 |
+
attention_mask,
|
1541 |
+
token_type_ids,
|
1542 |
+
position_ids,
|
1543 |
+
head_mask,
|
1544 |
+
deterministic=deterministic,
|
1545 |
+
output_attentions=output_attentions,
|
1546 |
+
output_hidden_states=output_hidden_states,
|
1547 |
+
return_dict=return_dict,
|
1548 |
+
)
|
1549 |
+
|
1550 |
+
hidden_states = outputs[0]
|
1551 |
+
|
1552 |
+
logits = self.qa_outputs(hidden_states)
|
1553 |
+
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
|
1554 |
+
start_logits = start_logits.squeeze(-1)
|
1555 |
+
end_logits = end_logits.squeeze(-1)
|
1556 |
+
|
1557 |
+
if not return_dict:
|
1558 |
+
return (start_logits, end_logits) + outputs[1:]
|
1559 |
+
|
1560 |
+
return FlaxQuestionAnsweringModelOutput(
|
1561 |
+
start_logits=start_logits,
|
1562 |
+
end_logits=end_logits,
|
1563 |
+
hidden_states=outputs.hidden_states,
|
1564 |
+
attentions=outputs.attentions,
|
1565 |
+
)
|
1566 |
+
|
1567 |
+
|
1568 |
+
@add_start_docstrings(
|
1569 |
+
"""
|
1570 |
+
Roberta Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
|
1571 |
+
layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
1572 |
+
""",
|
1573 |
+
ROBERTA_START_DOCSTRING,
|
1574 |
+
)
|
1575 |
+
class FlaxRobertaForQuestionAnswering(FlaxRobertaPreTrainedModel):
|
1576 |
+
module_class = FlaxRobertaForQuestionAnsweringModule
|
1577 |
+
|
1578 |
+
|
1579 |
+
append_call_sample_docstring(
|
1580 |
+
FlaxRobertaForQuestionAnswering,
|
1581 |
+
_CHECKPOINT_FOR_DOC,
|
1582 |
+
FlaxQuestionAnsweringModelOutput,
|
1583 |
+
_CONFIG_FOR_DOC,
|
1584 |
+
)
|
1585 |
+
|
1586 |
+
|
1587 |
+
class FlaxRobertaForCausalLMModule(nn.Module):
|
1588 |
+
config: RobertaConfig
|
1589 |
+
dtype: jnp.dtype = jnp.float32
|
1590 |
+
gradient_checkpointing: bool = False
|
1591 |
+
|
1592 |
+
def setup(self):
|
1593 |
+
self.roberta = FlaxRobertaModule(
|
1594 |
+
config=self.config,
|
1595 |
+
add_pooling_layer=False,
|
1596 |
+
dtype=self.dtype,
|
1597 |
+
gradient_checkpointing=self.gradient_checkpointing,
|
1598 |
+
)
|
1599 |
+
self.lm_head = FlaxRobertaLMHead(config=self.config, dtype=self.dtype)
|
1600 |
+
|
1601 |
+
def __call__(
|
1602 |
+
self,
|
1603 |
+
input_ids,
|
1604 |
+
attention_mask,
|
1605 |
+
position_ids,
|
1606 |
+
token_type_ids: Optional[jnp.ndarray] = None,
|
1607 |
+
head_mask: Optional[jnp.ndarray] = None,
|
1608 |
+
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
1609 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
1610 |
+
init_cache: bool = False,
|
1611 |
+
deterministic: bool = True,
|
1612 |
+
output_attentions: bool = False,
|
1613 |
+
output_hidden_states: bool = False,
|
1614 |
+
return_dict: bool = True,
|
1615 |
+
):
|
1616 |
+
# Model
|
1617 |
+
outputs = self.roberta(
|
1618 |
+
input_ids,
|
1619 |
+
attention_mask,
|
1620 |
+
token_type_ids,
|
1621 |
+
position_ids,
|
1622 |
+
head_mask,
|
1623 |
+
encoder_hidden_states=encoder_hidden_states,
|
1624 |
+
encoder_attention_mask=encoder_attention_mask,
|
1625 |
+
init_cache=init_cache,
|
1626 |
+
deterministic=deterministic,
|
1627 |
+
output_attentions=output_attentions,
|
1628 |
+
output_hidden_states=output_hidden_states,
|
1629 |
+
return_dict=return_dict,
|
1630 |
+
)
|
1631 |
+
|
1632 |
+
hidden_states = outputs[0]
|
1633 |
+
if self.config.tie_word_embeddings:
|
1634 |
+
shared_embedding = self.roberta.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
|
1635 |
+
else:
|
1636 |
+
shared_embedding = None
|
1637 |
+
|
1638 |
+
# Compute the prediction scores
|
1639 |
+
logits = self.lm_head(hidden_states, shared_embedding=shared_embedding)
|
1640 |
+
|
1641 |
+
if not return_dict:
|
1642 |
+
return (logits,) + outputs[1:]
|
1643 |
+
|
1644 |
+
return FlaxCausalLMOutputWithCrossAttentions(
|
1645 |
+
logits=logits,
|
1646 |
+
hidden_states=outputs.hidden_states,
|
1647 |
+
attentions=outputs.attentions,
|
1648 |
+
cross_attentions=outputs.cross_attentions,
|
1649 |
+
)
|
1650 |
+
|
1651 |
+
|
1652 |
+
@add_start_docstrings(
|
1653 |
+
"""
|
1654 |
+
Roberta Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for
|
1655 |
+
autoregressive tasks.
|
1656 |
+
""",
|
1657 |
+
ROBERTA_START_DOCSTRING,
|
1658 |
+
)
|
1659 |
+
class FlaxRobertaForCausalLM(FlaxRobertaPreTrainedModel):
|
1660 |
+
module_class = FlaxRobertaForCausalLMModule
|
1661 |
+
|
1662 |
+
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
|
1663 |
+
# initializing the cache
|
1664 |
+
batch_size, seq_length = input_ids.shape
|
1665 |
+
|
1666 |
+
past_key_values = self.init_cache(batch_size, max_length)
|
1667 |
+
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
|
1668 |
+
# But since the decoder uses a causal mask, those positions are masked anyway.
|
1669 |
+
# Thus, we can create a single static attention_mask here, which is more efficient for compilation
|
1670 |
+
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
|
1671 |
+
if attention_mask is not None:
|
1672 |
+
position_ids = attention_mask.cumsum(axis=-1) - 1
|
1673 |
+
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
|
1674 |
+
else:
|
1675 |
+
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
|
1676 |
+
|
1677 |
+
return {
|
1678 |
+
"past_key_values": past_key_values,
|
1679 |
+
"attention_mask": extended_attention_mask,
|
1680 |
+
"position_ids": position_ids,
|
1681 |
+
}
|
1682 |
+
|
1683 |
+
def update_inputs_for_generation(self, model_outputs, model_kwargs):
|
1684 |
+
model_kwargs["past_key_values"] = model_outputs.past_key_values
|
1685 |
+
model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
|
1686 |
+
return model_kwargs
|
1687 |
+
|
1688 |
+
|
1689 |
+
append_call_sample_docstring(
|
1690 |
+
FlaxRobertaForCausalLM,
|
1691 |
+
_CHECKPOINT_FOR_DOC,
|
1692 |
+
FlaxCausalLMOutputWithCrossAttentions,
|
1693 |
+
_CONFIG_FOR_DOC,
|
1694 |
+
)
|
EasyLM/models/roberta/roberta_train.py
ADDED
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
import pprint
|
3 |
+
from functools import partial
|
4 |
+
import re
|
5 |
+
|
6 |
+
from tqdm import tqdm, trange
|
7 |
+
import numpy as np
|
8 |
+
import mlxu
|
9 |
+
|
10 |
+
import jax
|
11 |
+
import jax.numpy as jnp
|
12 |
+
from jax.experimental.pjit import pjit, with_sharding_constraint
|
13 |
+
from jax.sharding import PartitionSpec as PS
|
14 |
+
from flax.training.train_state import TrainState
|
15 |
+
|
16 |
+
from EasyLM.data import DatasetFactory
|
17 |
+
from EasyLM.checkpoint import StreamingCheckpointer
|
18 |
+
from EasyLM.optimizers import OptimizerFactory
|
19 |
+
from EasyLM.jax_utils import (
|
20 |
+
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, get_float_dtype_by_name,
|
21 |
+
cross_entropy_loss_and_accuracy, named_tree_map, global_norm,
|
22 |
+
set_random_seed, average_metrics, get_weight_decay_mask,
|
23 |
+
make_shard_and_gather_fns, tree_apply
|
24 |
+
)
|
25 |
+
from EasyLM.models.roberta.roberta_model import (
|
26 |
+
RobertaConfig, FlaxRobertaForMaskedLMModule
|
27 |
+
)
|
28 |
+
|
29 |
+
|
30 |
+
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
31 |
+
seed=42,
|
32 |
+
mesh_dim='-1,1,1',
|
33 |
+
dtype='fp32',
|
34 |
+
mask_token_probability=0.15,
|
35 |
+
total_steps=10000,
|
36 |
+
load_roberta_config='',
|
37 |
+
update_roberta_config='',
|
38 |
+
load_checkpoint='',
|
39 |
+
load_dataset_state='',
|
40 |
+
log_freq=50,
|
41 |
+
save_model_freq=0,
|
42 |
+
save_milestone_freq=0,
|
43 |
+
eval_steps=0,
|
44 |
+
tokenizer=RobertaConfig.get_tokenizer_config(),
|
45 |
+
train_dataset=DatasetFactory.get_default_config(),
|
46 |
+
eval_dataset=DatasetFactory.get_default_config(),
|
47 |
+
optimizer=OptimizerFactory.get_default_config(),
|
48 |
+
checkpointer=StreamingCheckpointer.get_default_config(),
|
49 |
+
roberta=RobertaConfig.get_default_config(),
|
50 |
+
logger=mlxu.WandBLogger.get_default_config(),
|
51 |
+
log_all_worker=False,
|
52 |
+
jax_distributed=JaxDistributedConfig.get_default_config(),
|
53 |
+
)
|
54 |
+
|
55 |
+
|
56 |
+
def main(argv):
|
57 |
+
JaxDistributedConfig.initialize(FLAGS.jax_distributed)
|
58 |
+
variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
|
59 |
+
flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
|
60 |
+
logger = mlxu.WandBLogger(
|
61 |
+
config=FLAGS.logger,
|
62 |
+
variant=variant,
|
63 |
+
enable=FLAGS.log_all_worker or (jax.process_index() == 0),
|
64 |
+
)
|
65 |
+
set_random_seed(FLAGS.seed)
|
66 |
+
|
67 |
+
tokenizer = RobertaConfig.get_tokenizer(FLAGS.tokenizer)
|
68 |
+
dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer)
|
69 |
+
if FLAGS.load_dataset_state != '':
|
70 |
+
dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state))
|
71 |
+
|
72 |
+
if FLAGS.eval_steps > 0:
|
73 |
+
eval_dataset = DatasetFactory.load_dataset(
|
74 |
+
FLAGS.eval_dataset, dataset.tokenizer
|
75 |
+
)
|
76 |
+
eval_iterator = iter(eval_dataset)
|
77 |
+
|
78 |
+
seq_length = dataset.seq_length
|
79 |
+
|
80 |
+
if FLAGS.load_roberta_config != '':
|
81 |
+
roberta_config = RobertaConfig.load_config(FLAGS.load_roberta_config)
|
82 |
+
else:
|
83 |
+
roberta_config = RobertaConfig(**FLAGS.roberta)
|
84 |
+
|
85 |
+
if FLAGS.update_roberta_config != '':
|
86 |
+
roberta_config.update(dict(eval(FLAGS.update_roberta_config)))
|
87 |
+
|
88 |
+
roberta_config.update(dict(
|
89 |
+
bos_token_id=dataset.tokenizer.bos_token_id,
|
90 |
+
eos_token_id=dataset.tokenizer.eos_token_id,
|
91 |
+
pad_token_id=dataset.tokenizer.pad_token_id,
|
92 |
+
vocab_size=dataset.vocab_size,
|
93 |
+
))
|
94 |
+
|
95 |
+
model = FlaxRobertaForMaskedLMModule(
|
96 |
+
roberta_config, dtype=get_float_dtype_by_name(FLAGS.dtype)
|
97 |
+
)
|
98 |
+
|
99 |
+
optimizer, optimizer_info = OptimizerFactory.get_optimizer(
|
100 |
+
FLAGS.optimizer,
|
101 |
+
get_weight_decay_mask(RobertaConfig.get_weight_decay_exclusions()),
|
102 |
+
)
|
103 |
+
|
104 |
+
def create_trainstate_from_params(params):
|
105 |
+
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
|
106 |
+
|
107 |
+
def init_fn(rng):
|
108 |
+
rng_generator = JaxRNG(rng)
|
109 |
+
params = model.init(
|
110 |
+
input_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
|
111 |
+
position_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
|
112 |
+
attention_mask=jnp.ones((4, seq_length), dtype=jnp.int32),
|
113 |
+
token_type_ids=None,
|
114 |
+
head_mask=None,
|
115 |
+
rngs=rng_generator(roberta_config.rng_keys()),
|
116 |
+
)
|
117 |
+
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
|
118 |
+
|
119 |
+
def train_step(train_state, rng, batch):
|
120 |
+
rng_generator = JaxRNG(rng)
|
121 |
+
tokens = with_sharding_constraint(batch['target_tokens'], PS(('dp', 'fsdp')))
|
122 |
+
def loss_and_accuracy(params):
|
123 |
+
altered_tokens = jax.random.uniform(
|
124 |
+
rng_generator(), shape=tokens.shape
|
125 |
+
) < FLAGS.mask_token_probability
|
126 |
+
random_uniform = jax.random.uniform(rng_generator(), shape=tokens.shape)
|
127 |
+
altered_by_mask = altered_tokens & (random_uniform < 0.8)
|
128 |
+
altered_by_random = altered_tokens & (random_uniform >= 0.8) & (random_uniform < 0.9)
|
129 |
+
inputs = jnp.where(altered_by_mask, dataset.tokenizer.mask_token_id, tokens)
|
130 |
+
random_tokens = jax.random.randint(
|
131 |
+
rng_generator(), shape=tokens.shape, minval=0, maxval=dataset.vocab_size
|
132 |
+
)
|
133 |
+
inputs = jnp.where(altered_by_random, random_tokens, inputs)
|
134 |
+
logits = model.apply(
|
135 |
+
params, inputs,
|
136 |
+
attention_mask=jnp.ones_like(inputs),
|
137 |
+
token_type_ids=None,
|
138 |
+
position_ids=None,
|
139 |
+
head_mask=None,
|
140 |
+
deterministic=False,
|
141 |
+
rngs=rng_generator(roberta_config.rng_keys()),
|
142 |
+
).logits
|
143 |
+
return cross_entropy_loss_and_accuracy(logits, tokens, valid=altered_tokens)
|
144 |
+
grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
|
145 |
+
(loss, accuracy), grads = grad_fn(train_state.params)
|
146 |
+
train_state = train_state.apply_gradients(grads=grads)
|
147 |
+
metrics = dict(
|
148 |
+
loss=loss,
|
149 |
+
accuracy=accuracy,
|
150 |
+
learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
|
151 |
+
gradient_norm=global_norm(grads),
|
152 |
+
param_norm=global_norm(train_state.params),
|
153 |
+
)
|
154 |
+
return train_state, rng_generator(), metrics
|
155 |
+
|
156 |
+
def eval_step(train_state, rng, batch):
|
157 |
+
rng_generator = JaxRNG(rng)
|
158 |
+
tokens = with_sharding_constraint(batch['target_tokens'], PS(('dp', 'fsdp')))
|
159 |
+
altered_tokens = jax.random.uniform(
|
160 |
+
rng_generator(), shape=tokens.shape
|
161 |
+
) < FLAGS.mask_token_probability
|
162 |
+
random_uniform = jax.random.uniform(rng_generator(), shape=tokens.shape)
|
163 |
+
altered_by_mask = altered_tokens & (random_uniform < 0.8)
|
164 |
+
altered_by_random = altered_tokens & (random_uniform >= 0.8) & (random_uniform < 0.9)
|
165 |
+
inputs = jnp.where(altered_by_mask, dataset.tokenizer.mask_token_id, tokens)
|
166 |
+
random_tokens = jax.random.randint(
|
167 |
+
rng_generator(), shape=tokens.shape, minval=0, maxval=dataset.vocab_size
|
168 |
+
)
|
169 |
+
inputs = jnp.where(altered_by_random, random_tokens, inputs)
|
170 |
+
logits = model.apply(
|
171 |
+
train_state.params, inputs,
|
172 |
+
attention_mask=jnp.ones_like(inputs),
|
173 |
+
token_type_ids=None,
|
174 |
+
position_ids=None,
|
175 |
+
head_mask=None,
|
176 |
+
deterministic=False,
|
177 |
+
rngs=rng_generator(roberta_config.rng_keys()),
|
178 |
+
).logits
|
179 |
+
loss, accuracy = cross_entropy_loss_and_accuracy(logits, tokens, valid=altered_tokens)
|
180 |
+
metrics = dict(
|
181 |
+
eval_loss=loss,
|
182 |
+
eval_accuracy=accuracy,
|
183 |
+
)
|
184 |
+
return rng_generator(), metrics
|
185 |
+
|
186 |
+
train_state_shapes = jax.eval_shape(init_fn, next_rng())
|
187 |
+
train_state_partition = match_partition_rules(
|
188 |
+
RobertaConfig.get_partition_rules(), train_state_shapes
|
189 |
+
)
|
190 |
+
|
191 |
+
shard_fns, gather_fns = make_shard_and_gather_fns(
|
192 |
+
train_state_partition, train_state_shapes
|
193 |
+
)
|
194 |
+
checkpointer = StreamingCheckpointer(
|
195 |
+
FLAGS.checkpointer, logger.output_dir,
|
196 |
+
enable=jax.process_index() == 0
|
197 |
+
)
|
198 |
+
|
199 |
+
sharded_init_fn = pjit(
|
200 |
+
init_fn,
|
201 |
+
in_shardings=PS(),
|
202 |
+
out_shardings=train_state_partition
|
203 |
+
)
|
204 |
+
|
205 |
+
sharded_create_trainstate_from_params = pjit(
|
206 |
+
create_trainstate_from_params,
|
207 |
+
in_shardings=(train_state_partition.params, ),
|
208 |
+
out_shardings=train_state_partition,
|
209 |
+
donate_argnums=(0, ),
|
210 |
+
)
|
211 |
+
|
212 |
+
sharded_train_step = pjit(
|
213 |
+
train_step,
|
214 |
+
in_shardings=(train_state_partition, PS(), PS()),
|
215 |
+
out_shardings=(train_state_partition, PS(), PS()),
|
216 |
+
donate_argnums=(0, 1),
|
217 |
+
)
|
218 |
+
|
219 |
+
sharded_eval_step = pjit(
|
220 |
+
eval_step,
|
221 |
+
in_shardings=(train_state_partition, PS(), PS()),
|
222 |
+
out_shardings=(PS(), PS()),
|
223 |
+
donate_argnums=(1,),
|
224 |
+
)
|
225 |
+
|
226 |
+
def save_checkpoint(train_state, milestone=False):
|
227 |
+
step = int(jax.device_get(train_state.step))
|
228 |
+
metadata = dict(
|
229 |
+
step=step,
|
230 |
+
variant=variant,
|
231 |
+
flags=flags_config_dict,
|
232 |
+
roberta_config=roberta_config.to_dict(),
|
233 |
+
)
|
234 |
+
checkpointer.save_all(
|
235 |
+
train_state=train_state,
|
236 |
+
gather_fns=gather_fns,
|
237 |
+
metadata=metadata,
|
238 |
+
dataset=dataset.get_state_dict(),
|
239 |
+
milestone=milestone,
|
240 |
+
)
|
241 |
+
|
242 |
+
mesh = RobertaConfig.get_jax_mesh(FLAGS.mesh_dim)
|
243 |
+
with mesh:
|
244 |
+
train_state, restored_params = None, None
|
245 |
+
if FLAGS.load_checkpoint != '':
|
246 |
+
load_type, load_path = FLAGS.load_checkpoint.split('::', 1)
|
247 |
+
if load_type == 'huggingface':
|
248 |
+
restored_params = tree_apply(
|
249 |
+
shard_fns.params, roberta_config.load_pretrained(load_path)
|
250 |
+
)
|
251 |
+
train_state = None
|
252 |
+
else:
|
253 |
+
train_state, restored_params = checkpointer.load_trainstate_checkpoint(
|
254 |
+
FLAGS.load_checkpoint, train_state_shapes, shard_fns
|
255 |
+
)
|
256 |
+
|
257 |
+
if train_state is None and restored_params is None:
|
258 |
+
# Initialize from scratch
|
259 |
+
train_state = sharded_init_fn(next_rng())
|
260 |
+
elif train_state is None and restored_params is not None:
|
261 |
+
# Restore from params but initialize train_state
|
262 |
+
train_state = sharded_create_trainstate_from_params(restored_params)
|
263 |
+
del restored_params
|
264 |
+
|
265 |
+
start_step = int(jax.device_get(train_state.step))
|
266 |
+
|
267 |
+
if FLAGS.save_model_freq > 0:
|
268 |
+
save_checkpoint(train_state)
|
269 |
+
|
270 |
+
sharded_rng = next_rng()
|
271 |
+
|
272 |
+
step_counter = trange(start_step, FLAGS.total_steps, ncols=0)
|
273 |
+
|
274 |
+
for step, (batch, dataset_metrics) in zip(step_counter, dataset):
|
275 |
+
train_state, sharded_rng, metrics = sharded_train_step(
|
276 |
+
train_state, sharded_rng, batch
|
277 |
+
)
|
278 |
+
|
279 |
+
if step % FLAGS.log_freq == 0:
|
280 |
+
if FLAGS.eval_steps > 0:
|
281 |
+
eval_metric_list = []
|
282 |
+
for _ in range(FLAGS.eval_steps):
|
283 |
+
eval_batch, _ = next(eval_iterator)
|
284 |
+
sharded_rng, eval_metrics = sharded_eval_step(
|
285 |
+
train_state, sharded_rng, eval_batch
|
286 |
+
)
|
287 |
+
eval_metric_list.append(eval_metrics)
|
288 |
+
metrics.update(average_metrics(eval_metric_list))
|
289 |
+
|
290 |
+
log_metrics = {"step": step}
|
291 |
+
log_metrics.update(metrics)
|
292 |
+
log_metrics.update(dataset_metrics)
|
293 |
+
log_metrics = jax.device_get(log_metrics)
|
294 |
+
logger.log(log_metrics)
|
295 |
+
tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")
|
296 |
+
|
297 |
+
if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0:
|
298 |
+
save_checkpoint(train_state, milestone=True)
|
299 |
+
elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0:
|
300 |
+
save_checkpoint(train_state)
|
301 |
+
|
302 |
+
if FLAGS.save_model_freq > 0:
|
303 |
+
save_checkpoint(train_state)
|
304 |
+
|
305 |
+
|
306 |
+
if __name__ == "__main__":
|
307 |
+
mlxu.run(main)
|
EasyLM/optimizers.py
ADDED
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
from typing import Any, Mapping, Text, Tuple, Union, NamedTuple
|
4 |
+
from functools import partial
|
5 |
+
import re
|
6 |
+
import dataclasses
|
7 |
+
import random
|
8 |
+
|
9 |
+
from ml_collections.config_dict import config_dict
|
10 |
+
from ml_collections import ConfigDict
|
11 |
+
import jax
|
12 |
+
import jax.numpy as jnp
|
13 |
+
import numpy as np
|
14 |
+
from absl import logging
|
15 |
+
import optax
|
16 |
+
|
17 |
+
from EasyLM.jax_utils import float_to_dtype
|
18 |
+
|
19 |
+
|
20 |
+
class OptimizerFactory(object):
|
21 |
+
""" Configurable optax optimizer factory. """
|
22 |
+
|
23 |
+
def __init__(self):
|
24 |
+
raise NotImplementedError
|
25 |
+
|
26 |
+
@staticmethod
|
27 |
+
def get_default_config(updates=None):
|
28 |
+
config = ConfigDict()
|
29 |
+
config.accumulate_gradient_steps = 1
|
30 |
+
config.type = 'adamw'
|
31 |
+
config.palm_optimizer = PalmOptimizerFactory.get_default_config()
|
32 |
+
config.adamw_optimizer = AdamWOptimizerFactory.get_default_config()
|
33 |
+
config.lion_optimizer = LionOptimizerFactory.get_default_config()
|
34 |
+
|
35 |
+
if updates is not None:
|
36 |
+
config.update(ConfigDict(updates).copy_and_resolve_references())
|
37 |
+
return config
|
38 |
+
|
39 |
+
@classmethod
|
40 |
+
def get_optimizer(cls, config, weight_decay_mask=None):
|
41 |
+
config = cls.get_default_config(config)
|
42 |
+
if config.type == 'palm':
|
43 |
+
optimizer, optimizer_info = PalmOptimizerFactory.get_optimizer(
|
44 |
+
config.palm_optimizer, weight_decay_mask
|
45 |
+
)
|
46 |
+
elif config.type == 'adamw':
|
47 |
+
optimizer, optimizer_info = AdamWOptimizerFactory.get_optimizer(
|
48 |
+
config.adamw_optimizer, weight_decay_mask
|
49 |
+
)
|
50 |
+
elif config.type == 'lion':
|
51 |
+
optimizer, optimizer_info = LionOptimizerFactory.get_optimizer(
|
52 |
+
config.lion_optimizer, weight_decay_mask
|
53 |
+
)
|
54 |
+
else:
|
55 |
+
raise ValueError(f'Unknown optimizer type: {config.type}')
|
56 |
+
|
57 |
+
if config.accumulate_gradient_steps > 1:
|
58 |
+
optimizer = optax.MultiSteps(
|
59 |
+
optimizer, config.accumulate_gradient_steps
|
60 |
+
)
|
61 |
+
|
62 |
+
return optimizer, optimizer_info
|
63 |
+
|
64 |
+
|
65 |
+
class PalmOptimizerFactory(object):
|
66 |
+
""" PaLM optimizer factory. This optimizer implements the optimizer
|
67 |
+
described in the PaLM paper: https://arxiv.org/abs/2204.02311
|
68 |
+
"""
|
69 |
+
|
70 |
+
def __init__(self):
|
71 |
+
raise NotImplementedError
|
72 |
+
|
73 |
+
@staticmethod
|
74 |
+
def get_default_config(updates=None):
|
75 |
+
config = ConfigDict()
|
76 |
+
config.lr = 0.01
|
77 |
+
config.lr_warmup_steps = 10000
|
78 |
+
config.b1 = 0.9
|
79 |
+
config.b2 = 0.99
|
80 |
+
config.clip_gradient = 1.0
|
81 |
+
config.weight_decay = 1e-4
|
82 |
+
config.bf16_momentum = False
|
83 |
+
|
84 |
+
if updates is not None:
|
85 |
+
config.update(ConfigDict(updates).copy_and_resolve_references())
|
86 |
+
return config
|
87 |
+
|
88 |
+
@classmethod
|
89 |
+
def get_optimizer(cls, config, weight_decay_mask=None):
|
90 |
+
config = cls.get_default_config(config)
|
91 |
+
|
92 |
+
def learning_rate_schedule(step):
|
93 |
+
multiplier = config.lr / 0.01
|
94 |
+
return multiplier / jnp.sqrt(jnp.maximum(step, config.lr_warmup_steps))
|
95 |
+
|
96 |
+
def weight_decay_schedule(step):
|
97 |
+
multiplier = config.weight_decay / 1e-4
|
98 |
+
return -multiplier * jnp.square(learning_rate_schedule(step))
|
99 |
+
|
100 |
+
optimizer_info = dict(
|
101 |
+
learning_rate_schedule=learning_rate_schedule,
|
102 |
+
weight_decay_schedule=weight_decay_schedule,
|
103 |
+
)
|
104 |
+
|
105 |
+
optimizer = optax.chain(
|
106 |
+
optax.clip_by_global_norm(config.clip_gradient),
|
107 |
+
optax.adafactor(
|
108 |
+
learning_rate=learning_rate_schedule,
|
109 |
+
multiply_by_parameter_scale=True,
|
110 |
+
momentum=config.b1,
|
111 |
+
decay_rate=config.b2,
|
112 |
+
factored=False,
|
113 |
+
clipping_threshold=None,
|
114 |
+
dtype_momentum=jnp.bfloat16 if config.bf16_momentum else jnp.float32,
|
115 |
+
),
|
116 |
+
optax_add_scheduled_weight_decay(
|
117 |
+
weight_decay_schedule, weight_decay_mask
|
118 |
+
)
|
119 |
+
)
|
120 |
+
return optimizer, optimizer_info
|
121 |
+
|
122 |
+
|
123 |
+
class AdamWOptimizerFactory(object):
|
124 |
+
""" AdamW optimizer with cosine schedule. """
|
125 |
+
|
126 |
+
def __init__(self):
|
127 |
+
raise NotImplementedError
|
128 |
+
|
129 |
+
@staticmethod
|
130 |
+
def get_default_config(updates=None):
|
131 |
+
config = ConfigDict()
|
132 |
+
config.init_lr = 0.0
|
133 |
+
config.end_lr = 0.001
|
134 |
+
config.lr = 0.01
|
135 |
+
config.lr_warmup_steps = 2000
|
136 |
+
config.lr_decay_steps = 500000
|
137 |
+
config.b1 = 0.9
|
138 |
+
config.b2 = 0.95
|
139 |
+
config.clip_gradient = 1.0
|
140 |
+
config.weight_decay = 1e-4
|
141 |
+
config.bf16_momentum = False
|
142 |
+
config.multiply_by_parameter_scale = False
|
143 |
+
|
144 |
+
if updates is not None:
|
145 |
+
config.update(ConfigDict(updates).copy_and_resolve_references())
|
146 |
+
return config
|
147 |
+
|
148 |
+
@classmethod
|
149 |
+
def get_optimizer(cls, config, weight_decay_mask=None):
|
150 |
+
config = cls.get_default_config(config)
|
151 |
+
|
152 |
+
learning_rate_schedule = optax.warmup_cosine_decay_schedule(
|
153 |
+
init_value=config.init_lr,
|
154 |
+
peak_value=config.lr,
|
155 |
+
warmup_steps=config.lr_warmup_steps,
|
156 |
+
decay_steps=config.lr_decay_steps,
|
157 |
+
end_value=config.end_lr,
|
158 |
+
)
|
159 |
+
|
160 |
+
optimizer_info = dict(
|
161 |
+
learning_rate_schedule=learning_rate_schedule,
|
162 |
+
)
|
163 |
+
|
164 |
+
if config.multiply_by_parameter_scale:
|
165 |
+
optimizer = optax.chain(
|
166 |
+
optax.clip_by_global_norm(config.clip_gradient),
|
167 |
+
optax.adafactor(
|
168 |
+
learning_rate=learning_rate_schedule,
|
169 |
+
multiply_by_parameter_scale=True,
|
170 |
+
momentum=config.b1,
|
171 |
+
decay_rate=config.b2,
|
172 |
+
factored=False,
|
173 |
+
clipping_threshold=None,
|
174 |
+
dtype_momentum=jnp.bfloat16 if config.bf16_momentum else jnp.float32,
|
175 |
+
),
|
176 |
+
optax_add_scheduled_weight_decay(
|
177 |
+
lambda step: -learning_rate_schedule(step) * config.weight_decay,
|
178 |
+
weight_decay_mask
|
179 |
+
)
|
180 |
+
)
|
181 |
+
else:
|
182 |
+
optimizer = optax.chain(
|
183 |
+
optax.clip_by_global_norm(config.clip_gradient),
|
184 |
+
optax.adamw(
|
185 |
+
learning_rate=learning_rate_schedule,
|
186 |
+
weight_decay=config.weight_decay,
|
187 |
+
b1=config.b1,
|
188 |
+
b2=config.b2,
|
189 |
+
mask=weight_decay_mask,
|
190 |
+
mu_dtype=jnp.bfloat16 if config.bf16_momentum else jnp.float32,
|
191 |
+
),
|
192 |
+
)
|
193 |
+
|
194 |
+
return optimizer, optimizer_info
|
195 |
+
|
196 |
+
class LionOptimizerFactory(object):
|
197 |
+
""" Lion optimizer with cosine schedule. """
|
198 |
+
|
199 |
+
def __init__(self):
|
200 |
+
raise NotImplementedError
|
201 |
+
|
202 |
+
@staticmethod
|
203 |
+
def get_default_config(updates=None):
|
204 |
+
config = ConfigDict()
|
205 |
+
config.init_lr = 0.0
|
206 |
+
config.end_lr = 0.0001
|
207 |
+
config.lr = 0.001
|
208 |
+
config.lr_warmup_steps = 2000
|
209 |
+
config.lr_decay_steps = 500000
|
210 |
+
config.b1 = 0.9
|
211 |
+
config.b2 = 0.98
|
212 |
+
config.clip_gradient = 1.0
|
213 |
+
config.weight_decay = 1e-3
|
214 |
+
config.bf16_momentum = False
|
215 |
+
config.lr_schedule_type = "warmup_cosine_decay_schedule"
|
216 |
+
config.lr_decay_rate = 0.98
|
217 |
+
|
218 |
+
if updates is not None:
|
219 |
+
config.update(ConfigDict(updates).copy_and_resolve_references())
|
220 |
+
return config
|
221 |
+
|
222 |
+
@classmethod
|
223 |
+
def get_optimizer(cls, config, weight_decay_mask=None):
|
224 |
+
config = cls.get_default_config(config)
|
225 |
+
|
226 |
+
if config.lr_schedule_type == "warmup_cosine_decay_schedule":
|
227 |
+
learning_rate_schedule = optax.warmup_cosine_decay_schedule(
|
228 |
+
init_value=config.init_lr,
|
229 |
+
peak_value=config.lr,
|
230 |
+
warmup_steps=config.lr_warmup_steps,
|
231 |
+
decay_steps=config.lr_decay_steps,
|
232 |
+
end_value=config.end_lr,
|
233 |
+
)
|
234 |
+
elif config.lr_schedule_type == "warmup_constant":
|
235 |
+
learning_rate_schedule = optax.join_schedules(
|
236 |
+
[
|
237 |
+
optax.linear_schedule(
|
238 |
+
init_value=config.init_lr,
|
239 |
+
end_value=config.lr,
|
240 |
+
transition_steps=config.lr_warmup_steps,
|
241 |
+
),
|
242 |
+
optax.constant_schedule(config.lr),
|
243 |
+
],
|
244 |
+
[config.lr_warmup_steps],
|
245 |
+
)
|
246 |
+
elif config.lr_schedule_type == "exponential_decay":
|
247 |
+
learning_rate_schedule = optax.exponential_decay(
|
248 |
+
init_value=config.lr,
|
249 |
+
transition_steps=config.lr_decay_steps,
|
250 |
+
decay_rate=config.lr_decay_rate,
|
251 |
+
transition_begin=0,
|
252 |
+
staircase=False,
|
253 |
+
end_value=config.end_lr,
|
254 |
+
)
|
255 |
+
else:
|
256 |
+
raise ValueError('config.lr_schedule_type must be "warmup_cosine_decay_schedule", "warmup_constant", or "exponential_decay"')
|
257 |
+
|
258 |
+
optimizer_info = dict(
|
259 |
+
learning_rate_schedule=learning_rate_schedule,
|
260 |
+
)
|
261 |
+
|
262 |
+
optimizer = optax.chain(
|
263 |
+
optax.clip_by_global_norm(config.clip_gradient),
|
264 |
+
optax.lion(
|
265 |
+
learning_rate=learning_rate_schedule,
|
266 |
+
weight_decay=config.weight_decay,
|
267 |
+
b1=config.b1,
|
268 |
+
b2=config.b2,
|
269 |
+
mask=weight_decay_mask,
|
270 |
+
mu_dtype=jnp.bfloat16 if config.bf16_momentum else jnp.float32,
|
271 |
+
),
|
272 |
+
)
|
273 |
+
|
274 |
+
return optimizer, optimizer_info
|
275 |
+
|
276 |
+
|
277 |
+
class OptaxScheduledWeightDecayState(NamedTuple):
|
278 |
+
count: jax.Array
|
279 |
+
|
280 |
+
|
281 |
+
def optax_add_scheduled_weight_decay(schedule_fn, mask=None):
|
282 |
+
""" Apply weight decay with schedule. """
|
283 |
+
|
284 |
+
def init_fn(params):
|
285 |
+
del params
|
286 |
+
return OptaxScheduledWeightDecayState(count=jnp.zeros([], jnp.int32))
|
287 |
+
|
288 |
+
def update_fn(updates, state, params):
|
289 |
+
if params is None:
|
290 |
+
raise ValueError('Params cannot be None for weight decay!')
|
291 |
+
|
292 |
+
weight_decay = schedule_fn(state.count)
|
293 |
+
updates = jax.tree_util.tree_map(
|
294 |
+
lambda g, p: g + weight_decay * p, updates, params
|
295 |
+
)
|
296 |
+
return updates, OptaxScheduledWeightDecayState(
|
297 |
+
count=optax.safe_int32_increment(state.count)
|
298 |
+
)
|
299 |
+
|
300 |
+
if mask is not None:
|
301 |
+
return optax.masked(optax.GradientTransformation(init_fn, update_fn), mask)
|
302 |
+
return optax.GradientTransformation(init_fn, update_fn)
|
EasyLM/scripts/__init__.py
ADDED
File without changes
|
EasyLM/scripts/benchmark_attention.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
from time import time
|
3 |
+
import os
|
4 |
+
import numpy as np
|
5 |
+
import jax
|
6 |
+
import jax.flatten_util
|
7 |
+
import jax.numpy as jnp
|
8 |
+
import mlxu
|
9 |
+
from EasyLM.bpt import blockwise_attn
|
10 |
+
from EasyLM.jax_utils import (
|
11 |
+
get_float_dtype_by_name, set_random_seed, next_rng, JaxRNG
|
12 |
+
)
|
13 |
+
|
14 |
+
|
15 |
+
FLAGS, _ = mlxu.define_flags_with_default(
|
16 |
+
seed=42,
|
17 |
+
dtype='fp32',
|
18 |
+
embed_dim=2048,
|
19 |
+
n_heads=16,
|
20 |
+
ref_attn_seq_len=2048,
|
21 |
+
eff_attn_seq_len=16384,
|
22 |
+
batch_size=1,
|
23 |
+
query_chunk_size=2048,
|
24 |
+
key_chunk_size=2048,
|
25 |
+
warmup_steps=40,
|
26 |
+
steps=200,
|
27 |
+
)
|
28 |
+
|
29 |
+
|
30 |
+
def main(argv):
|
31 |
+
|
32 |
+
def random_kqv(rng_key, seq_len):
|
33 |
+
rng_generator = JaxRNG(rng_key)
|
34 |
+
kqv = []
|
35 |
+
for i in range(3):
|
36 |
+
kqv.append(
|
37 |
+
jax.random.normal(
|
38 |
+
rng_generator(),
|
39 |
+
(FLAGS.batch_size, seq_len, FLAGS.n_heads, FLAGS.embed_dim // FLAGS.n_heads),
|
40 |
+
dtype=get_float_dtype_by_name(FLAGS.dtype)
|
41 |
+
)
|
42 |
+
)
|
43 |
+
return tuple(kqv)
|
44 |
+
|
45 |
+
def reference_attn(query, key, value):
|
46 |
+
dtype = get_float_dtype_by_name(FLAGS.dtype)
|
47 |
+
query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
|
48 |
+
logits = jnp.einsum("bqhc,bkhc->bhqk", query, key)
|
49 |
+
mask_value = jnp.finfo(logits.dtype).min
|
50 |
+
_, q_seq_len, _, _ = query.shape
|
51 |
+
_, kv_seq_len, _, _ = key.shape
|
52 |
+
mask_shape = (q_seq_len, kv_seq_len)
|
53 |
+
row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
|
54 |
+
col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
|
55 |
+
causal_mask = (row_ids < col_ids)[None, None, :, :]
|
56 |
+
logits = logits + jnp.where(causal_mask, mask_value, 0.0)
|
57 |
+
weights = jax.nn.softmax(logits, axis=-1)
|
58 |
+
out = jnp.einsum("bhqk,bkhc->bqhc", weights, value)
|
59 |
+
return out
|
60 |
+
|
61 |
+
def efficient_attention(query, key, value):
|
62 |
+
dtype = get_float_dtype_by_name(FLAGS.dtype)
|
63 |
+
return blockwise_attn(
|
64 |
+
query, key, value,
|
65 |
+
bias=None,
|
66 |
+
deterministic=True,
|
67 |
+
dropout_rng=None,
|
68 |
+
attn_pdrop=0.0,
|
69 |
+
causal=True,
|
70 |
+
query_chunk_size=FLAGS.query_chunk_size,
|
71 |
+
key_chunk_size=FLAGS.key_chunk_size,
|
72 |
+
dtype=get_float_dtype_by_name(FLAGS.dtype),
|
73 |
+
policy=jax.checkpoint_policies.nothing_saveable(),
|
74 |
+
precision=None,
|
75 |
+
float32_logits=True,
|
76 |
+
prevent_cse=True,
|
77 |
+
)
|
78 |
+
|
79 |
+
|
80 |
+
@partial(jax.jit, static_argnums=(1,))
|
81 |
+
def reference_attn_forward_backward(rng_key, seq_len):
|
82 |
+
@partial(jax.grad, argnums=(0, 1, 2))
|
83 |
+
@partial(jax.checkpoint, policy=jax.checkpoint_policies.nothing_saveable())
|
84 |
+
def grad_fn(query, key, value):
|
85 |
+
out = reference_attn(query, key, value)
|
86 |
+
return jnp.mean(out)
|
87 |
+
|
88 |
+
query, key, value = random_kqv(rng_key, seq_len)
|
89 |
+
return jax.flatten_util.ravel_pytree(
|
90 |
+
grad_fn(query, key, value)[1]
|
91 |
+
)[0].mean()
|
92 |
+
|
93 |
+
@partial(jax.jit, static_argnums=(1,))
|
94 |
+
def efficient_attn_forward_backward(rng_key, seq_len):
|
95 |
+
@partial(jax.grad, argnums=(0, 1, 2))
|
96 |
+
def grad_fn(query, key, value):
|
97 |
+
out = efficient_attention(query, key, value)
|
98 |
+
return jnp.mean(out)
|
99 |
+
|
100 |
+
query, key, value = random_kqv(rng_key, seq_len)
|
101 |
+
return jax.flatten_util.ravel_pytree(
|
102 |
+
grad_fn(query, key, value)[1]
|
103 |
+
)[0].mean()
|
104 |
+
|
105 |
+
|
106 |
+
set_random_seed(FLAGS.seed)
|
107 |
+
|
108 |
+
jax.block_until_ready(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len))
|
109 |
+
jax.block_until_ready(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len))
|
110 |
+
|
111 |
+
all_results = []
|
112 |
+
for i in range(FLAGS.warmup_steps):
|
113 |
+
all_results.append(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len))
|
114 |
+
jax.block_until_ready(all_results)
|
115 |
+
|
116 |
+
start_time = time()
|
117 |
+
all_results = []
|
118 |
+
for i in range(FLAGS.steps):
|
119 |
+
all_results.append(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len))
|
120 |
+
|
121 |
+
jax.block_until_ready(all_results)
|
122 |
+
elapsed_time_ref_attn = time() - start_time
|
123 |
+
print(f'Reference attention: {elapsed_time_ref_attn:.3f} seconds')
|
124 |
+
|
125 |
+
|
126 |
+
all_results = []
|
127 |
+
for i in range(FLAGS.warmup_steps):
|
128 |
+
all_results.append(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len))
|
129 |
+
jax.block_until_ready(all_results)
|
130 |
+
|
131 |
+
|
132 |
+
start_time = time()
|
133 |
+
all_results = []
|
134 |
+
for i in range(FLAGS.steps):
|
135 |
+
all_results.append(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len))
|
136 |
+
|
137 |
+
jax.block_until_ready(all_results)
|
138 |
+
elapsed_time_efficient_attn = time() - start_time
|
139 |
+
print(f'Efficient attention: {elapsed_time_efficient_attn:.3f} seconds')
|
140 |
+
|
141 |
+
flops_ratio = (FLAGS.eff_attn_seq_len / FLAGS.ref_attn_seq_len) ** 2
|
142 |
+
efficiency = elapsed_time_ref_attn / elapsed_time_efficient_attn * flops_ratio
|
143 |
+
print(f'Efficiency: {efficiency:.3f}')
|
144 |
+
|
145 |
+
|
146 |
+
if __name__ == '__main__':
|
147 |
+
mlxu.run(main)
|
148 |
+
|
149 |
+
|
150 |
+
|
EasyLM/scripts/convert_checkpoint.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This script converts model checkpoint trained by EsayLM to a standard
|
2 |
+
# mspack checkpoint that can be loaded by huggingface transformers or
|
3 |
+
# flax.serialization.msgpack_restore. Such conversion allows models to be
|
4 |
+
# used by other frameworks that integrate with huggingface transformers.
|
5 |
+
|
6 |
+
import pprint
|
7 |
+
from functools import partial
|
8 |
+
import os
|
9 |
+
import numpy as np
|
10 |
+
import mlxu
|
11 |
+
import jax.numpy as jnp
|
12 |
+
import flax.serialization
|
13 |
+
from EasyLM.checkpoint import StreamingCheckpointer
|
14 |
+
from EasyLM.jax_utils import float_to_dtype
|
15 |
+
|
16 |
+
|
17 |
+
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
18 |
+
load_checkpoint='',
|
19 |
+
output_file='',
|
20 |
+
streaming=False,
|
21 |
+
float_dtype='bf16',
|
22 |
+
)
|
23 |
+
|
24 |
+
|
25 |
+
def main(argv):
|
26 |
+
assert FLAGS.load_checkpoint != '' and FLAGS.output_file != '', 'input and output must be specified'
|
27 |
+
params = StreamingCheckpointer.load_trainstate_checkpoint(
|
28 |
+
FLAGS.load_checkpoint, disallow_trainstate=True
|
29 |
+
)[1]['params']
|
30 |
+
|
31 |
+
if FLAGS.streaming:
|
32 |
+
StreamingCheckpointer.save_train_state_to_file(
|
33 |
+
params, FLAGS.output_file, float_dtype=FLAGS.float_dtype
|
34 |
+
)
|
35 |
+
else:
|
36 |
+
params = float_to_dtype(params, FLAGS.float_dtype)
|
37 |
+
with mlxu.open_file(FLAGS.output, 'wb') as fout:
|
38 |
+
fout.write(flax.serialization.msgpack_serialize(params, in_place=True))
|
39 |
+
|
40 |
+
|
41 |
+
if __name__ == "__main__":
|
42 |
+
mlxu.run(main)
|
EasyLM/scripts/diff_checkpoint.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This script converts model checkpoint trained by EsayLM to a standard
|
2 |
+
# mspack checkpoint that can be loaded by huggingface transformers or
|
3 |
+
# flax.serialization.msgpack_restore. Such conversion allows models to be
|
4 |
+
# used by other frameworks that integrate with huggingface transformers.
|
5 |
+
|
6 |
+
import pprint
|
7 |
+
from functools import partial
|
8 |
+
import os
|
9 |
+
import numpy as np
|
10 |
+
import jax
|
11 |
+
import jax.numpy as jnp
|
12 |
+
import flax.serialization
|
13 |
+
import mlxu
|
14 |
+
from EasyLM.checkpoint import StreamingCheckpointer
|
15 |
+
from EasyLM.jax_utils import float_to_dtype
|
16 |
+
|
17 |
+
|
18 |
+
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
19 |
+
recover_diff=False,
|
20 |
+
load_base_checkpoint='',
|
21 |
+
load_target_checkpoint='',
|
22 |
+
output_file='',
|
23 |
+
streaming=True,
|
24 |
+
float_dtype='bf16',
|
25 |
+
)
|
26 |
+
|
27 |
+
|
28 |
+
def main(argv):
|
29 |
+
assert FLAGS.load_base_checkpoint != '' and FLAGS.load_target_checkpoint != ''
|
30 |
+
assert FLAGS.output_file != ''
|
31 |
+
base_params = StreamingCheckpointer.load_trainstate_checkpoint(
|
32 |
+
FLAGS.load_base_checkpoint, disallow_trainstate=True
|
33 |
+
)[1]['params']
|
34 |
+
|
35 |
+
target_params = StreamingCheckpointer.load_trainstate_checkpoint(
|
36 |
+
FLAGS.load_target_checkpoint, disallow_trainstate=True
|
37 |
+
)[1]['params']
|
38 |
+
|
39 |
+
if FLAGS.recover_diff:
|
40 |
+
params = jax.tree_util.tree_map(
|
41 |
+
lambda b, t: b + t, base_params, target_params
|
42 |
+
)
|
43 |
+
else:
|
44 |
+
params = jax.tree_util.tree_map(
|
45 |
+
lambda b, t: t - b, base_params, target_params
|
46 |
+
)
|
47 |
+
|
48 |
+
if FLAGS.streaming:
|
49 |
+
StreamingCheckpointer.save_train_state_to_file(
|
50 |
+
params, FLAGS.output_file, float_dtype=FLAGS.float_dtype
|
51 |
+
)
|
52 |
+
else:
|
53 |
+
params = float_to_dtype(params, FLAGS.float_dtype)
|
54 |
+
with mlxu.open_file(FLAGS.output, 'wb') as fout:
|
55 |
+
fout.write(flax.serialization.msgpack_serialize(params, in_place=True))
|
56 |
+
|
57 |
+
|
58 |
+
if __name__ == "__main__":
|
59 |
+
mlxu.run(main)
|
EasyLM/scripts/lm_eval_harness.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This script runs lm_eval_harness evaluations against a served language model.
|
2 |
+
# Typically, you need to run a language model server first, e.g.:
|
3 |
+
# python -m EasyLM.models.gptj.gptj_serve ...
|
4 |
+
|
5 |
+
import dataclasses
|
6 |
+
import pprint
|
7 |
+
from functools import partial
|
8 |
+
import os
|
9 |
+
from tqdm import tqdm, trange
|
10 |
+
import numpy as np
|
11 |
+
import mlxu
|
12 |
+
|
13 |
+
from flax.traverse_util import flatten_dict
|
14 |
+
from lm_eval import evaluator, tasks
|
15 |
+
from lm_eval.base import LM
|
16 |
+
|
17 |
+
from EasyLM.serving import LMClient
|
18 |
+
|
19 |
+
|
20 |
+
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
21 |
+
tasks='wsc,piqa,winogrande,openbookqa,logiqa',
|
22 |
+
shots=0,
|
23 |
+
limit=0,
|
24 |
+
write_out=False,
|
25 |
+
lm_client=LMClient.get_default_config(),
|
26 |
+
logger=mlxu.WandBLogger.get_default_config(),
|
27 |
+
)
|
28 |
+
|
29 |
+
|
30 |
+
class LMEvalHarnessInterface(LM):
|
31 |
+
|
32 |
+
def __init__(self, lm_client):
|
33 |
+
self.lm_client = lm_client
|
34 |
+
|
35 |
+
def greedy_until(self, inputs):
|
36 |
+
prefix, until = zip(*inputs)
|
37 |
+
return self.lm_client.greedy_until(prefix, until)
|
38 |
+
|
39 |
+
def loglikelihood_rolling(self, inputs):
|
40 |
+
loglikelihood, is_greedy = self.lm_client.loglikelihood_rolling(inputs)
|
41 |
+
return list(zip(loglikelihood, is_greedy))
|
42 |
+
|
43 |
+
def loglikelihood(self, inputs):
|
44 |
+
prefix, text = zip(*inputs)
|
45 |
+
loglikelihood, is_greedy = self.lm_client.loglikelihood(prefix, text)
|
46 |
+
return list(zip(loglikelihood, is_greedy))
|
47 |
+
|
48 |
+
|
49 |
+
def main(argv):
|
50 |
+
logger = mlxu.WandBLogger(
|
51 |
+
config=FLAGS.logger, variant=mlxu.get_user_flags(FLAGS, FLAGS_DEF)
|
52 |
+
)
|
53 |
+
model = LMEvalHarnessInterface(LMClient(FLAGS.lm_client))
|
54 |
+
task_list = FLAGS.tasks.split(',')
|
55 |
+
results = evaluator.evaluate(
|
56 |
+
model, tasks.get_task_dict(task_list), False, FLAGS.shots,
|
57 |
+
limit=None if FLAGS.limit <= 0 else FLAGS.limit,
|
58 |
+
write_out=FLAGS.write_out,
|
59 |
+
)
|
60 |
+
logger.log(flatten_dict(results['results'], sep='/'))
|
61 |
+
pprint.pprint(results)
|
62 |
+
|
63 |
+
|
64 |
+
if __name__ == "__main__":
|
65 |
+
mlxu.run(main)
|
EasyLM/scripts/lm_eval_json.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import mlxu
|
3 |
+
from EasyLM.serving import LMClient
|
4 |
+
|
5 |
+
|
6 |
+
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
7 |
+
input_file='',
|
8 |
+
output_file='',
|
9 |
+
prefix_field='prefix',
|
10 |
+
text_field='text',
|
11 |
+
until_field='until',
|
12 |
+
eval_type='loglikelihood',
|
13 |
+
lm_client=LMClient.get_default_config(),
|
14 |
+
)
|
15 |
+
|
16 |
+
|
17 |
+
def main(argv):
|
18 |
+
lm_client = LMClient(FLAGS.lm_client)
|
19 |
+
with mlxu.open_file(FLAGS.input_file, 'r') as fin:
|
20 |
+
input_data = json.load(fin)
|
21 |
+
|
22 |
+
if FLAGS.eval_type == 'loglikelihood':
|
23 |
+
prefix = input_data[FLAGS.prefix_field]
|
24 |
+
text = input_data[FLAGS.text_field]
|
25 |
+
loglikelihoods, is_greedys = lm_client.loglikelihood(prefix, text)
|
26 |
+
output_data = {
|
27 |
+
'loglikelihood': loglikelihoods,
|
28 |
+
'is_greedy': is_greedys,
|
29 |
+
}
|
30 |
+
elif FLAGS.eval_type == 'loglikelihood_rolling':
|
31 |
+
text = input_data[FLAGS.text_field]
|
32 |
+
loglikelihoods, is_greedys = lm_client.loglikelihood_rolling(text)
|
33 |
+
output_data = {
|
34 |
+
'loglikelihood': loglikelihoods,
|
35 |
+
'is_greedy': is_greedys,
|
36 |
+
}
|
37 |
+
elif FLAGS.eval_type == 'greedy_until':
|
38 |
+
prefix = input_data[FLAGS.prefix_field]
|
39 |
+
until = input_data[FLAGS.until_field]
|
40 |
+
output_data = {'output_text': lm_client.greedy_until(prefix, until)}
|
41 |
+
elif FLAGS.eval_type == 'generate':
|
42 |
+
prefix = input_data[FLAGS.prefix_field]
|
43 |
+
output_data = {'output_text': lm_client.generate(prefix)}
|
44 |
+
else:
|
45 |
+
raise ValueError(f'Unknown eval_type: {FLAGS.eval_type}')
|
46 |
+
|
47 |
+
with mlxu.open_file(FLAGS.output_file, 'w') as fout:
|
48 |
+
json.dump(output_data, fout)
|
49 |
+
|
50 |
+
|
51 |
+
if __name__ == "__main__":
|
52 |
+
mlxu.run(main)
|
EasyLM/serving.py
ADDED
@@ -0,0 +1,566 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
import pprint
|
3 |
+
from functools import partial
|
4 |
+
import re
|
5 |
+
import os
|
6 |
+
from threading import Lock
|
7 |
+
import urllib
|
8 |
+
import time
|
9 |
+
from typing import List, Optional, Union
|
10 |
+
|
11 |
+
from pydantic import BaseModel
|
12 |
+
import absl.logging
|
13 |
+
from tqdm import tqdm, trange
|
14 |
+
import numpy as np
|
15 |
+
import mlxu
|
16 |
+
from ml_collections import ConfigDict
|
17 |
+
import uvicorn
|
18 |
+
from fastapi import FastAPI
|
19 |
+
import gradio as gr
|
20 |
+
import requests
|
21 |
+
from requests.exceptions import Timeout, ConnectionError
|
22 |
+
|
23 |
+
|
24 |
+
class InferenceRequest(BaseModel):
|
25 |
+
prefix_text: Optional[List[str]] = None
|
26 |
+
text: Optional[List[str]] = None
|
27 |
+
until: Optional[Union[List[str], List[List[str]]]] = None
|
28 |
+
temperature: Optional[float] = None
|
29 |
+
|
30 |
+
|
31 |
+
class ChatRequest(BaseModel):
|
32 |
+
prompt: str
|
33 |
+
context: str = ''
|
34 |
+
temperature: Optional[float] = None
|
35 |
+
|
36 |
+
|
37 |
+
class LMServer(object):
|
38 |
+
""" HTTP server for serving langauge models. """
|
39 |
+
|
40 |
+
@staticmethod
|
41 |
+
def get_default_config(updates=None):
|
42 |
+
config = ConfigDict()
|
43 |
+
config.host = '0.0.0.0'
|
44 |
+
config.port = 5007
|
45 |
+
config.batch_size = 1
|
46 |
+
config.logging = False
|
47 |
+
config.pre_compile = 'loglikelihood'
|
48 |
+
config.default_temperature = 1.0
|
49 |
+
config.greedy_until_max_length = 5000
|
50 |
+
config.prepend_to_prefix = ''
|
51 |
+
config.append_to_prefix = ''
|
52 |
+
config.prepend_to_text = ''
|
53 |
+
config.append_to_text = ''
|
54 |
+
config.chat_prepend_text = ''
|
55 |
+
config.chat_user_prefix = ''
|
56 |
+
config.chat_user_suffix = ''
|
57 |
+
config.chat_lm_prefix = ''
|
58 |
+
config.chat_lm_suffix = ''
|
59 |
+
config.notes = ''
|
60 |
+
|
61 |
+
if updates is not None:
|
62 |
+
config.update(ConfigDict(updates).copy_and_resolve_references())
|
63 |
+
return config
|
64 |
+
|
65 |
+
def __init__(self, config):
|
66 |
+
self.config = self.get_default_config(config)
|
67 |
+
self.lock = Lock()
|
68 |
+
self.app = FastAPI()
|
69 |
+
self.app.post('/loglikelihood')(self.serve_loglikelihood)
|
70 |
+
self.app.post('/loglikelihood-rolling')(self.serve_loglikelihood_rolling)
|
71 |
+
self.app.post('/generate')(self.serve_generate)
|
72 |
+
self.app.post('/greedy-until')(self.serve_greedy_until)
|
73 |
+
self.app.post('/chat')(self.serve_chat)
|
74 |
+
self.app.get('/ready')(self.serve_ready)
|
75 |
+
self.app = gr.mount_gradio_app(self.app, self.create_chat_app(), '/')
|
76 |
+
|
77 |
+
@staticmethod
|
78 |
+
def loglikelihood(prefix_text, text):
|
79 |
+
raise NotImplementedError()
|
80 |
+
|
81 |
+
@staticmethod
|
82 |
+
def loglikelihood_rolling(text):
|
83 |
+
raise NotImplementedError()
|
84 |
+
|
85 |
+
@staticmethod
|
86 |
+
def generate(text, temperature):
|
87 |
+
raise NotImplementedError()
|
88 |
+
|
89 |
+
@staticmethod
|
90 |
+
def greedy_until(prefix_text, until, max_length):
|
91 |
+
raise NotImplementedError()
|
92 |
+
|
93 |
+
@staticmethod
|
94 |
+
def to_list(x):
|
95 |
+
if isinstance(x, np.ndarray):
|
96 |
+
return x.tolist()
|
97 |
+
return x
|
98 |
+
|
99 |
+
def serve_ready(self):
|
100 |
+
return 'Ready!\n'
|
101 |
+
|
102 |
+
def serve_loglikelihood(self, data: InferenceRequest):
|
103 |
+
with self.lock:
|
104 |
+
if self.config.logging:
|
105 |
+
absl.logging.info(
|
106 |
+
'\n========= Serving Log Likelihood Request ========= \n'
|
107 |
+
+ pprint.pformat(data) + '\n'
|
108 |
+
)
|
109 |
+
|
110 |
+
if data.prefix_text is None:
|
111 |
+
data.prefix_text = ['' for _ in data.text]
|
112 |
+
|
113 |
+
prefix_text = [
|
114 |
+
self.config.prepend_to_prefix + p + self.config.append_to_prefix
|
115 |
+
for p in data.prefix_text
|
116 |
+
]
|
117 |
+
text = [
|
118 |
+
self.config.prepend_to_text + t + self.config.append_to_text
|
119 |
+
for t in data.text
|
120 |
+
]
|
121 |
+
|
122 |
+
log_likelihood = []
|
123 |
+
is_greedy = []
|
124 |
+
for i in trange(0, len(text), self.config.batch_size, ncols=0):
|
125 |
+
batch_prefix_text = prefix_text[i:i + self.config.batch_size]
|
126 |
+
batch_text = text[i:i + self.config.batch_size]
|
127 |
+
batch_size = len(batch_text)
|
128 |
+
|
129 |
+
if batch_size < self.config.batch_size:
|
130 |
+
extra = self.config.batch_size - batch_size
|
131 |
+
batch_prefix_text.extend(['a' for _ in range(extra)])
|
132 |
+
batch_text.extend(['a' for _ in range(extra)])
|
133 |
+
|
134 |
+
batch_log_likelihood, batch_is_greedy = self.loglikelihood(
|
135 |
+
batch_prefix_text, batch_text
|
136 |
+
)
|
137 |
+
batch_log_likelihood = self.to_list(batch_log_likelihood)
|
138 |
+
batch_is_greedy = self.to_list(batch_is_greedy)
|
139 |
+
log_likelihood.extend(batch_log_likelihood[:batch_size])
|
140 |
+
is_greedy.extend(batch_is_greedy[:batch_size])
|
141 |
+
|
142 |
+
output = {
|
143 |
+
'prefix_text': data.prefix_text,
|
144 |
+
'text': data.text,
|
145 |
+
'log_likelihood': log_likelihood,
|
146 |
+
'is_greedy': is_greedy,
|
147 |
+
}
|
148 |
+
if self.config.logging:
|
149 |
+
absl.logging.info(
|
150 |
+
'\n========= Output ========= \n'
|
151 |
+
+ pprint.pformat(output) + '\n'
|
152 |
+
)
|
153 |
+
|
154 |
+
return output
|
155 |
+
|
156 |
+
def serve_loglikelihood_rolling(self, data: InferenceRequest):
|
157 |
+
with self.lock:
|
158 |
+
if self.config.logging:
|
159 |
+
absl.logging.info(
|
160 |
+
'\n========= Serving Log Likelihood Request ========= \n'
|
161 |
+
+ pprint.pformat(data) + '\n'
|
162 |
+
)
|
163 |
+
|
164 |
+
text = [
|
165 |
+
self.config.prepend_to_text + t + self.config.append_to_text
|
166 |
+
for t in data.text
|
167 |
+
]
|
168 |
+
log_likelihood = []
|
169 |
+
is_greedy = []
|
170 |
+
for i in trange(0, len(text), self.config.batch_size, ncols=0):
|
171 |
+
batch_text = text[i:i + self.config.batch_size]
|
172 |
+
batch_size = len(batch_text)
|
173 |
+
|
174 |
+
if batch_size < self.config.batch_size:
|
175 |
+
extra = self.config.batch_size - batch_size
|
176 |
+
batch_text.extend(['a' for _ in range(extra)])
|
177 |
+
|
178 |
+
batch_log_likelihood, batch_is_greedy = self.loglikelihood_rolling(
|
179 |
+
batch_text
|
180 |
+
)
|
181 |
+
batch_log_likelihood = self.to_list(batch_log_likelihood)
|
182 |
+
batch_is_greedy = self.to_list(batch_is_greedy)
|
183 |
+
log_likelihood.extend(batch_log_likelihood[:batch_size])
|
184 |
+
is_greedy.extend(batch_is_greedy[:batch_size])
|
185 |
+
|
186 |
+
output = {
|
187 |
+
'text': data.text,
|
188 |
+
'log_likelihood': log_likelihood,
|
189 |
+
'is_greedy': is_greedy,
|
190 |
+
}
|
191 |
+
if self.config.logging:
|
192 |
+
absl.logging.info(
|
193 |
+
'\n========= Output ========= \n'
|
194 |
+
+ pprint.pformat(output) + '\n'
|
195 |
+
)
|
196 |
+
|
197 |
+
return output
|
198 |
+
|
199 |
+
def serve_generate(self, data: InferenceRequest):
|
200 |
+
with self.lock:
|
201 |
+
if self.config.logging:
|
202 |
+
absl.logging.info(
|
203 |
+
'\n========= Serving Generate Request ========= \n'
|
204 |
+
+ pprint.pformat(data) + '\n'
|
205 |
+
)
|
206 |
+
prefix_text = [
|
207 |
+
self.config.prepend_to_prefix + p + self.config.append_to_prefix
|
208 |
+
for p in data.prefix_text
|
209 |
+
]
|
210 |
+
|
211 |
+
if data.temperature is None:
|
212 |
+
data.temperature = self.config.default_temperature
|
213 |
+
|
214 |
+
output_text = []
|
215 |
+
for i in trange(0, len(prefix_text), self.config.batch_size, ncols=0):
|
216 |
+
batch_prefix_text = prefix_text[i:i + self.config.batch_size]
|
217 |
+
batch_size = len(batch_prefix_text)
|
218 |
+
|
219 |
+
if batch_size < self.config.batch_size:
|
220 |
+
extra = self.config.batch_size - batch_size
|
221 |
+
batch_prefix_text.extend(['a' for _ in range(extra)])
|
222 |
+
|
223 |
+
batch_output_text = self.generate(
|
224 |
+
batch_prefix_text,
|
225 |
+
temperature=data.temperature,
|
226 |
+
)
|
227 |
+
output_text.extend(self.to_list(batch_output_text)[:batch_size])
|
228 |
+
|
229 |
+
output = {
|
230 |
+
'prefix_text': data.prefix_text,
|
231 |
+
'output_text': output_text,
|
232 |
+
'temperature': data.temperature,
|
233 |
+
}
|
234 |
+
if self.config.logging:
|
235 |
+
absl.logging.info(
|
236 |
+
'\n========= Output ========= \n'
|
237 |
+
+ pprint.pformat(output) + '\n'
|
238 |
+
)
|
239 |
+
return output
|
240 |
+
|
241 |
+
def serve_greedy_until(self, data: InferenceRequest):
|
242 |
+
with self.lock:
|
243 |
+
if self.config.logging:
|
244 |
+
absl.logging.info(
|
245 |
+
'\n========= Serving Greedy Until Request ========= \n'
|
246 |
+
+ pprint.pformat(data) + '\n'
|
247 |
+
)
|
248 |
+
prefix_text = [
|
249 |
+
self.config.prepend_to_prefix + p + self.config.append_to_prefix
|
250 |
+
for p in data.prefix_text
|
251 |
+
]
|
252 |
+
until = data.until
|
253 |
+
max_length = self.config.greedy_until_max_length
|
254 |
+
|
255 |
+
output_text = []
|
256 |
+
for i in range(0, len(prefix_text), self.config.batch_size):
|
257 |
+
batch_prefix_text = prefix_text[i:i + self.config.batch_size]
|
258 |
+
batch_until = until[i:i + self.config.batch_size]
|
259 |
+
batch_size = len(batch_prefix_text)
|
260 |
+
|
261 |
+
batch_output_text = self.greedy_until(batch_prefix_text, batch_until, max_length)
|
262 |
+
output_text.extend(self.to_list(batch_output_text)[:batch_size])
|
263 |
+
|
264 |
+
output = {
|
265 |
+
'prefix_text': data.prefix_text,
|
266 |
+
'until': data.until,
|
267 |
+
'max_length': max_length,
|
268 |
+
'output_text': output_text,
|
269 |
+
}
|
270 |
+
if self.config.logging:
|
271 |
+
absl.logging.info(
|
272 |
+
'\n========= Output ========= \n'
|
273 |
+
+ pprint.pformat(output) + '\n'
|
274 |
+
)
|
275 |
+
return output
|
276 |
+
|
277 |
+
def process_chat(self, prompt, context, temperature):
|
278 |
+
context = (
|
279 |
+
context + self.config.chat_user_prefix
|
280 |
+
+ prompt + self.config.chat_user_suffix
|
281 |
+
+ self.config.chat_lm_prefix
|
282 |
+
)
|
283 |
+
response = self.generate(
|
284 |
+
[self.config.chat_prepend_text + context],
|
285 |
+
temperature=float(temperature),
|
286 |
+
)[0]
|
287 |
+
context = context + response + self.config.chat_lm_suffix
|
288 |
+
return response, context
|
289 |
+
|
290 |
+
def serve_chat(self, data: ChatRequest):
|
291 |
+
if data.temperature is None:
|
292 |
+
data.temperature = self.config.default_temperature
|
293 |
+
response, context = self.process_chat(
|
294 |
+
data.prompt, data.context,
|
295 |
+
temperature=data.temperature,
|
296 |
+
)
|
297 |
+
return {
|
298 |
+
'response': response,
|
299 |
+
'context': context,
|
300 |
+
'temperature': data.temperature,
|
301 |
+
}
|
302 |
+
|
303 |
+
def create_chat_app(self):
|
304 |
+
with gr.Blocks(analytics_enabled=False, title='EasyLM Chat') as gradio_chatbot:
|
305 |
+
gr.Markdown('# Chatbot Powered by [EasyLM](https://github.com/young-geng/EasyLM)')
|
306 |
+
gr.Markdown(self.config.notes)
|
307 |
+
chatbot = gr.Chatbot(label='Chat history')
|
308 |
+
msg = gr.Textbox(
|
309 |
+
placeholder='Type your message here...',
|
310 |
+
show_label=False
|
311 |
+
)
|
312 |
+
with gr.Row():
|
313 |
+
send = gr.Button('Send')
|
314 |
+
regenerate = gr.Button('Regenerate', interactive=False)
|
315 |
+
clear = gr.Button('Reset')
|
316 |
+
|
317 |
+
temp_slider = gr.Slider(
|
318 |
+
label='Temperature', minimum=0, maximum=2.0,
|
319 |
+
value=self.config.default_temperature
|
320 |
+
)
|
321 |
+
|
322 |
+
context_state = gr.State(['', ''])
|
323 |
+
|
324 |
+
def user_fn(user_message, history, context):
|
325 |
+
return {
|
326 |
+
msg: gr.update(value='', interactive=False),
|
327 |
+
clear: gr.update(interactive=False),
|
328 |
+
send: gr.update(interactive=False),
|
329 |
+
regenerate: gr.update(interactive=False),
|
330 |
+
chatbot: history + [[user_message, None]],
|
331 |
+
context_state: [context[1], context[1]],
|
332 |
+
}
|
333 |
+
|
334 |
+
def model_fn(history, context, temperature):
|
335 |
+
history[-1][1], new_context = self.process_chat(
|
336 |
+
history[-1][0], context[0], temperature
|
337 |
+
)
|
338 |
+
return {
|
339 |
+
msg: gr.update(value='', interactive=True),
|
340 |
+
clear: gr.update(interactive=True),
|
341 |
+
send: gr.update(interactive=True),
|
342 |
+
chatbot: history,
|
343 |
+
context_state: [context[0], new_context],
|
344 |
+
regenerate: gr.update(interactive=True),
|
345 |
+
}
|
346 |
+
|
347 |
+
def regenerate_fn():
|
348 |
+
return {
|
349 |
+
msg: gr.update(value='', interactive=False),
|
350 |
+
clear: gr.update(interactive=False),
|
351 |
+
send: gr.update(interactive=False),
|
352 |
+
regenerate: gr.update(interactive=False),
|
353 |
+
}
|
354 |
+
|
355 |
+
def clear_fn():
|
356 |
+
return {
|
357 |
+
chatbot: None,
|
358 |
+
msg: '',
|
359 |
+
context_state: ['', ''],
|
360 |
+
regenerate: gr.update(interactive=False),
|
361 |
+
}
|
362 |
+
|
363 |
+
msg.submit(
|
364 |
+
user_fn,
|
365 |
+
inputs=[msg, chatbot, context_state],
|
366 |
+
outputs=[msg, clear, send, chatbot, context_state, regenerate],
|
367 |
+
queue=False
|
368 |
+
).then(
|
369 |
+
model_fn,
|
370 |
+
inputs=[chatbot, context_state, temp_slider],
|
371 |
+
outputs=[msg, clear, send, chatbot, context_state, regenerate],
|
372 |
+
queue=True
|
373 |
+
)
|
374 |
+
send.click(
|
375 |
+
user_fn,
|
376 |
+
inputs=[msg, chatbot, context_state],
|
377 |
+
outputs=[msg, clear, send, chatbot, context_state, regenerate],
|
378 |
+
queue=False
|
379 |
+
).then(
|
380 |
+
model_fn,
|
381 |
+
inputs=[chatbot, context_state, temp_slider],
|
382 |
+
outputs=[msg, clear, send, chatbot, context_state, regenerate],
|
383 |
+
queue=True
|
384 |
+
)
|
385 |
+
regenerate.click(
|
386 |
+
regenerate_fn,
|
387 |
+
inputs=None,
|
388 |
+
outputs=[msg, clear, send, regenerate],
|
389 |
+
queue=False
|
390 |
+
).then(
|
391 |
+
model_fn,
|
392 |
+
inputs=[chatbot, context_state, temp_slider],
|
393 |
+
outputs=[msg, clear, send, chatbot, context_state, regenerate],
|
394 |
+
queue=True
|
395 |
+
)
|
396 |
+
clear.click(
|
397 |
+
clear_fn,
|
398 |
+
inputs=None,
|
399 |
+
outputs=[chatbot, msg, context_state, regenerate],
|
400 |
+
queue=False
|
401 |
+
)
|
402 |
+
|
403 |
+
gradio_chatbot.queue(concurrency_count=1)
|
404 |
+
return gradio_chatbot
|
405 |
+
|
406 |
+
def run(self):
|
407 |
+
if self.config.pre_compile != '':
|
408 |
+
if self.config.pre_compile == 'all':
|
409 |
+
pre_compile = ['loglikelihood', 'generate', 'greedy_until', 'chat']
|
410 |
+
else:
|
411 |
+
pre_compile = self.config.pre_compile.split(',')
|
412 |
+
|
413 |
+
pre_compile_data = ['a' for _ in range(self.config.batch_size)]
|
414 |
+
for task in pre_compile:
|
415 |
+
if task == 'loglikelihood':
|
416 |
+
self.loglikelihood(pre_compile_data, pre_compile_data)
|
417 |
+
self.loglikelihood_rolling(pre_compile_data)
|
418 |
+
elif task == 'generate':
|
419 |
+
self.generate(pre_compile_data, 1.0)
|
420 |
+
elif task == 'greedy_until':
|
421 |
+
self.greedy_until(
|
422 |
+
pre_compile_data, pre_compile_data,
|
423 |
+
self.config.greedy_until_max_length
|
424 |
+
)
|
425 |
+
elif task == 'chat':
|
426 |
+
self.process_chat('a', 'a', 1.0)
|
427 |
+
else:
|
428 |
+
raise ValueError(f'Invalid precompile task: {task}!')
|
429 |
+
|
430 |
+
uvicorn.run(self.app, host=self.config.host, port=self.config.port)
|
431 |
+
|
432 |
+
|
433 |
+
class LMClient(object):
|
434 |
+
""" A simple client for the LM server. """
|
435 |
+
|
436 |
+
@staticmethod
|
437 |
+
def get_default_config(updates=None):
|
438 |
+
config = ConfigDict()
|
439 |
+
config.url = 'http://localhost:5007'
|
440 |
+
config.batch_size = 1
|
441 |
+
config.wait_for_ready = True
|
442 |
+
config.dummy = False
|
443 |
+
|
444 |
+
if updates is not None:
|
445 |
+
config.update(ConfigDict(updates).copy_and_resolve_references())
|
446 |
+
return config
|
447 |
+
|
448 |
+
def __init__(self, config=None):
|
449 |
+
self.config = self.get_default_config(config)
|
450 |
+
if self.config.wait_for_ready:
|
451 |
+
self.wait_for_ready()
|
452 |
+
|
453 |
+
def wait_for_ready(self):
|
454 |
+
if self.config.dummy:
|
455 |
+
return
|
456 |
+
while True:
|
457 |
+
try:
|
458 |
+
requests.get(urllib.parse.urljoin(self.config.url, 'ready'))
|
459 |
+
return
|
460 |
+
except (Timeout, ConnectionError) as e:
|
461 |
+
time.sleep(10)
|
462 |
+
|
463 |
+
@staticmethod
|
464 |
+
def batched(iterator, batch_size):
|
465 |
+
batch = []
|
466 |
+
for example in iterator:
|
467 |
+
batch.append(example)
|
468 |
+
if len(batch) == batch_size:
|
469 |
+
yield batch
|
470 |
+
batch = []
|
471 |
+
if len(batch) > 0:
|
472 |
+
yield batch
|
473 |
+
|
474 |
+
def loglikelihood(self, prefix, text):
|
475 |
+
prefix, text = list(prefix), list(text)
|
476 |
+
if self.config.dummy:
|
477 |
+
return [-1.0 for _ in text], [False for _ in text]
|
478 |
+
|
479 |
+
log_likelihood = []
|
480 |
+
is_greedy = []
|
481 |
+
|
482 |
+
batched_iterator = list(zip(
|
483 |
+
self.batched(prefix, self.config.batch_size),
|
484 |
+
self.batched(text, self.config.batch_size)
|
485 |
+
))
|
486 |
+
for batch_prefix, batch_text in tqdm(batched_iterator, ncols=0):
|
487 |
+
response = requests.post(
|
488 |
+
urllib.parse.urljoin(self.config.url, 'loglikelihood'),
|
489 |
+
json={'prefix_text': batch_prefix, 'text': batch_text}
|
490 |
+
).json()
|
491 |
+
log_likelihood.extend(response['log_likelihood'])
|
492 |
+
is_greedy.extend(response['is_greedy'])
|
493 |
+
|
494 |
+
return log_likelihood, is_greedy
|
495 |
+
|
496 |
+
def loglikelihood_rolling(self, text):
|
497 |
+
text = list(text)
|
498 |
+
if self.config.dummy:
|
499 |
+
return [-1.0 for _ in text], [False for _ in text]
|
500 |
+
|
501 |
+
log_likelihood = []
|
502 |
+
is_greedy = []
|
503 |
+
batched_iterator = list(self.batched(text, self.config.batch_size))
|
504 |
+
for batch_text in tqdm(batched_iterator, ncols=0):
|
505 |
+
response = requests.post(
|
506 |
+
urllib.parse.urljoin(self.config.url, 'loglikelihood-rolling'),
|
507 |
+
json={'text': batch_text}
|
508 |
+
).json()
|
509 |
+
log_likelihood.extend(response['log_likelihood'])
|
510 |
+
is_greedy.extend(response['is_greedy'])
|
511 |
+
return log_likelihood, is_greedy
|
512 |
+
|
513 |
+
def greedy_until(self, prefix, until):
|
514 |
+
prefix, until = list(prefix), list(until)
|
515 |
+
if self.config.dummy:
|
516 |
+
results = []
|
517 |
+
for u in until:
|
518 |
+
if isinstance(u, str):
|
519 |
+
results.append('dummy text ' + u)
|
520 |
+
else:
|
521 |
+
results.append('dummy text ' + u[0])
|
522 |
+
return results
|
523 |
+
|
524 |
+
batched_iterator = list(zip(
|
525 |
+
self.batched(prefix, self.config.batch_size),
|
526 |
+
self.batched(until, self.config.batch_size),
|
527 |
+
))
|
528 |
+
output_text = []
|
529 |
+
for batch_prefix, batch_until in tqdm(batched_iterator, ncols=0):
|
530 |
+
response = requests.post(
|
531 |
+
urllib.parse.urljoin(self.config.url, 'greedy-until'),
|
532 |
+
json={'prefix_text': batch_prefix, 'until': batch_until}
|
533 |
+
).json()
|
534 |
+
output_text.extend(response['output_text'])
|
535 |
+
return output_text
|
536 |
+
|
537 |
+
def generate(self, prefix, temperature=None):
|
538 |
+
prefix = list(prefix)
|
539 |
+
if self.config.dummy:
|
540 |
+
return ['' for _ in prefix]
|
541 |
+
|
542 |
+
output_text = []
|
543 |
+
batched_iterator = list(self.batched(prefix, self.config.batch_size))
|
544 |
+
for batch_prefix in tqdm(batched_iterator, ncols=0):
|
545 |
+
response = requests.post(
|
546 |
+
urllib.parse.urljoin(self.config.url, 'generate'),
|
547 |
+
json={
|
548 |
+
'prefix_text': batch_prefix,
|
549 |
+
'temperature': temperature,
|
550 |
+
}
|
551 |
+
).json()
|
552 |
+
output_text.extend(response['output_text'])
|
553 |
+
return output_text
|
554 |
+
|
555 |
+
def chat(self, prompt, context, temperature=None):
|
556 |
+
if self.config.dummy:
|
557 |
+
return ''
|
558 |
+
response = requests.post(
|
559 |
+
urllib.parse.urljoin(self.config.url, 'chat'),
|
560 |
+
json={
|
561 |
+
'prompt': prompt,
|
562 |
+
'context': context,
|
563 |
+
'temperature': temperature,
|
564 |
+
}
|
565 |
+
).json()
|
566 |
+
return response['response'], response['context']
|
config.json
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"LlamaForCausalLM"
|
4 |
+
],
|
5 |
+
"attention_bias": false,
|
6 |
+
"bos_token_id": 1,
|
7 |
+
"eos_token_id": 2,
|
8 |
+
"hidden_act": "silu",
|
9 |
+
"hidden_size": 3200,
|
10 |
+
"initializer_range": 0.02,
|
11 |
+
"intermediate_size": 8640,
|
12 |
+
"max_position_embeddings": 2048,
|
13 |
+
"model_type": "llama",
|
14 |
+
"num_attention_heads": 32,
|
15 |
+
"num_hidden_layers": 26,
|
16 |
+
"num_key_value_heads": 32,
|
17 |
+
"pretraining_tp": 1,
|
18 |
+
"rms_norm_eps": 1e-06,
|
19 |
+
"rope_scaling": null,
|
20 |
+
"rope_theta": 10000.0,
|
21 |
+
"tie_word_embeddings": false,
|
22 |
+
"torch_dtype": "float16",
|
23 |
+
"transformers_version": "4.34.0.dev0",
|
24 |
+
"use_cache": true,
|
25 |
+
"vocab_size": 64256
|
26 |
+
}
|
convert_to_hf_model.sh
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
JAX_PLATFORM_NAME=cpu python3 -m EasyLM.models.llama.convert_easylm_to_hf \
|
2 |
+
--load_checkpoint='' \
|
3 |
+
--model_size='3b' \
|
4 |
+
--output_dir='./'
|
pretrain_llama_3b.sh
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#! /bin/bash
|
2 |
+
|
3 |
+
# Put your WANDB API key here to enable logging to wandb.
|
4 |
+
export WANDB_API_KEY=''
|
5 |
+
|
6 |
+
# TPU specific flags to improve training throughput
|
7 |
+
export LIBTPU_INIT_ARGS='--xla_jf_spmd_threshold_for_windowed_einsum_mib=0 --xla_tpu_spmd_threshold_for_allgather_cse=10000 --xla_tpu_spmd_rewrite_einsum_with_reshape=true --xla_enable_async_all_gather=true --jax_enable_async_collective_offload=true --xla_tpu_enable_latency_hiding_scheduler=true TPU_MEGACORE=MEGACORE_DENSE'
|
8 |
+
|
9 |
+
|
10 |
+
python3 -m EasyLM.models.llama.llama_train \
|
11 |
+
--jax_distributed.initialize_jax_distributed=True \
|
12 |
+
--mesh_dim='1,-1,1' \
|
13 |
+
--dtype='bf16' \
|
14 |
+
--total_steps=900000 \
|
15 |
+
--eval_freq=50000 \
|
16 |
+
--log_freq=1000 \
|
17 |
+
--save_model_freq=2000 \
|
18 |
+
--save_milestone_freq=50000 \
|
19 |
+
--load_llama_config='3b' \
|
20 |
+
--update_llama_config='' \
|
21 |
+
--load_dataset_state='' \
|
22 |
+
--load_checkpoint='' \
|
23 |
+
--tokenizer.pretrained_model_name_or_path='./' \
|
24 |
+
--optimizer.type='lion' \
|
25 |
+
--optimizer.lion_optimizer.weight_decay=1.0 \
|
26 |
+
--optimizer.lion_optimizer.lr_schedule_type='warmup_constant' \
|
27 |
+
--optimizer.lion_optimizer.lr=3e-4 \
|
28 |
+
--optimizer.lion_optimizer.end_lr=3e-5 \
|
29 |
+
--optimizer.lion_optimizer.lr_warmup_steps=60000 \
|
30 |
+
--optimizer.lion_optimizer.lr_decay_steps=100000 \
|
31 |
+
--optimizer.lion_optimizer.bf16_momentum=True \
|
32 |
+
--train_dataset.type='huggingface' \
|
33 |
+
--train_dataset.text_processor.fields='text' \
|
34 |
+
--train_dataset.huggingface_dataset.path='/researchdisk/lm_training_dataset_first_stage' \
|
35 |
+
--train_dataset.huggingface_dataset.split='train' \
|
36 |
+
--train_dataset.huggingface_dataset.seq_length=2048 \
|
37 |
+
--train_dataset.huggingface_dataset.batch_size=64 \
|
38 |
+
--eval_dataset.type='huggingface' \
|
39 |
+
--eval_dataset.text_processor.fields='text' \
|
40 |
+
--eval_dataset.huggingface_dataset.path='/researchdisk/lm_training_dataset_first_stage' \
|
41 |
+
--eval_dataset.huggingface_dataset.split='validation' \
|
42 |
+
--eval_dataset.huggingface_dataset.seq_length=2048 \
|
43 |
+
--eval_dataset.huggingface_dataset.batch_size=64 \
|
44 |
+
--checkpointer.save_optimizer_state=True \
|
45 |
+
--logger.online=True \
|
46 |
+
--logger.prefix='EasyLM' \
|
47 |
+
--logger.project="llama-3b-finnish-v2" \
|
48 |
+
--logger.output_dir="gs://finnish-nlp-research-us/llama-3b-v2-checkpoint" \
|
49 |
+
--logger.wandb_dir="./"
|
50 |
+
|