feat: vmap optimizer (#166)
Files changed:
- src/dalle_mini/model/modeling.py (+35 -23)
- src/dalle_mini/model/partitions.py (+1 -1)
- tools/train/config/mega/config.json (+2 -2)
- tools/train/train.py (+118 -89)
src/dalle_mini/model/modeling.py
CHANGED
@@ -946,15 +946,6 @@ class FlaxBartEncoderLayerCollection(nn.Module):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)

-        # postln is already applied in every layer
-        if self.config.use_final_ln_encoder and self.config.ln_positions != "postln":
-            hidden_states = norm(
-                self.config.ln_type,
-                dtype=self.dtype,
-                epsilon=1e-05,
-                use_scale=self.config.force_ln_scale,
-            )(hidden_states)
-
         outputs = [
             hidden_states,
             all_hidden_states,
@@ -1034,7 +1025,7 @@ class FlaxBartDecoderLayerCollection(nn.Module):
             self.config,
             dtype=self.dtype,
             add_norm=self.config.ln_positions == "postln",
-            name="
+            name="FlaxBartDecoderLayers",
         )(
             hidden_states,
             attention_mask,
@@ -1086,15 +1077,6 @@ class FlaxBartDecoderLayerCollection(nn.Module):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)

-        # postln is already applied in every layer
-        if self.config.use_final_ln_decoder and self.config.ln_positions != "postln":
-            hidden_states = norm(
-                self.config.ln_type,
-                dtype=self.dtype,
-                epsilon=1e-05,
-                use_scale=self.config.force_ln_scale,
-            )(hidden_states)
-
         outputs = [
             hidden_states,
             all_hidden_states,
@@ -1146,6 +1128,17 @@ class FlaxBartEncoder(nn.Module):
             self.config.ln_type, dtype=self.dtype, epsilon=1e-05
         )

+        # postln is already applied in every layer
+        if self.config.use_final_ln_encoder and self.config.ln_positions != "postln":
+            self.final_ln = norm(
+                self.config.ln_type,
+                dtype=self.dtype,
+                epsilon=1e-05,
+                use_scale=self.config.force_ln_scale,
+            )
+        else:
+            self.final_ln = None
+
     def __call__(
         self,
         input_ids,
@@ -1177,11 +1170,16 @@ class FlaxBartEncoder(nn.Module):
             return_dict=return_dict,
         )

+        if self.final_ln is None:
+            final_output = outputs[0]
+        else:
+            final_output = self.final_ln(outputs[0])
+
         if not return_dict:
-            return outputs
+            return (final_output,) + outputs[1:]

         return FlaxBaseModelOutput(
-            last_hidden_state=
+            last_hidden_state=final_output,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
@@ -1223,6 +1221,15 @@ class FlaxBartDecoder(nn.Module):
             self.config.ln_type, dtype=self.dtype, epsilon=1e-05
         )

+        # postln is already applied in every layer
+        if self.config.use_final_ln_decoder and self.config.ln_positions != "postln":
+            self.final_ln = norm(
+                self.config.ln_type,
+                dtype=self.dtype,
+                epsilon=1e-05,
+                use_scale=self.config.force_ln_scale,
+            )
+
     def __call__(
         self,
         input_ids,
@@ -1260,11 +1267,16 @@ class FlaxBartDecoder(nn.Module):
             return_dict=return_dict,
         )

+        if self.final_ln is None:
+            final_output = outputs[0]
+        else:
+            final_output = self.final_ln(outputs[0])
+
         if not return_dict:
-            return outputs
+            return (final_output,) + outputs[1:]

         return FlaxBaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=
+            last_hidden_state=final_output,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
             cross_attentions=outputs.cross_attentions,
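The modeling change above moves the optional final layer norm out of the (scanned) FlaxBartEncoderLayerCollection / FlaxBartDecoderLayerCollection and into FlaxBartEncoder / FlaxBartDecoder, which now build it in setup() and apply it to the first output of the layer stack. Below is a minimal Flax sketch of that pattern, not the dalle_mini code itself; the flags use_final_ln and ln_positions are stand-ins for the real config fields.

import flax.linen as nn
import jax
import jax.numpy as jnp


class ToyEncoder(nn.Module):
    use_final_ln: bool = True
    ln_positions: str = "preln"  # "postln" would mean each layer already ends with a norm

    def setup(self):
        self.layer = nn.Dense(16)  # stand-in for the scanned layer collection
        # build the final norm only when the layers themselves don't end with one
        if self.use_final_ln and self.ln_positions != "postln":
            self.final_ln = nn.LayerNorm(epsilon=1e-05)
        else:
            self.final_ln = None

    def __call__(self, hidden_states):
        hidden_states = self.layer(hidden_states)
        if self.final_ln is not None:
            hidden_states = self.final_ln(hidden_states)
        return hidden_states


params = ToyEncoder().init(jax.random.PRNGKey(0), jnp.ones((1, 4, 16)))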
src/dalle_mini/model/partitions.py
CHANGED
@@ -65,7 +65,7 @@ def set_partitions(in_dict, use_scan):
         print(f"Unmatched -> {k}")
     l = list(result.keys())
     if use_scan:
-        # add None dimension to
+        # add None dimension to layers
         result = {
             k: (P(*(None,) + v) if v is not None else None)
             if any(x in k for x in ["FlaxBartEncoderLayers", "FlaxBartDecoderLayers"])
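The comment fix above sits next to the logic that matters for scan: parameters under FlaxBartEncoderLayers / FlaxBartDecoderLayers are stacked along a leading layer axis, so their PartitionSpec gets an extra leading None (replicated) dimension. A tiny illustration of that transformation, assuming the same PartitionSpec import the training code already uses:

from jax.experimental import PartitionSpec as P

spec = ("mp", None)                 # spec tuple for a single, unscanned kernel
scanned_spec = P(*(None,) + spec)   # P(None, "mp", None): the layer axis is not sharded
print(scanned_spec)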
tools/train/config/mega/config.json
CHANGED
@@ -7,14 +7,14 @@
   "decoder_attention_heads": 32,
   "decoder_ffn_dim": 4096,
   "decoder_layerdrop": 0.0,
-  "decoder_layers":
+  "decoder_layers": 24,
   "decoder_start_token_id": 16384,
   "do_sample": true,
   "dropout": 0.0,
   "encoder_attention_heads": 32,
   "encoder_ffn_dim": 4096,
   "encoder_layerdrop": 0.0,
-  "encoder_layers":
+  "encoder_layers": 24,
   "encoder_vocab_size": 50272,
   "eos_token_id": 16385,
   "force_ln_scale": false,
tools/train/train.py
CHANGED
@@ -38,11 +38,10 @@ import optax
 import transformers
 import wandb
 from datasets import Dataset
+from flax import core, struct, traverse_util
 from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.serialization import from_bytes, to_bytes
-from flax.training import train_state
 from flax.training.common_utils import onehot
-from jax import ShapeDtypeStruct
 from jax.experimental import PartitionSpec, maps
 from jax.experimental.compilation_cache import compilation_cache as cc
 from jax.experimental.pjit import pjit, with_sharding_constraint
@@ -526,60 +525,78 @@ class TrainingArguments:
         self.dp_devices = jax.device_count() // self.mp_devices


-
+def split_params(data):
+    """Split params between scanned and non-scanned"""
+    flat = traverse_util.flatten_dict(unfreeze(data))
+    split = {"standard": {}, "scanned_encoder": {}, "scanned_decoder": {}}
+    for k, v in flat.items():
+        if "FlaxBartEncoderLayers" in k:
+            split["scanned_encoder"][k] = v
+        elif "FlaxBartDecoderLayers" in k:
+            split["scanned_decoder"][k] = v
+        else:
+            split["standard"][k] = v
+    for k, v in split.items():
+        split[k] = freeze(traverse_util.unflatten_dict(v))
+    return split
+
+
+def unsplit_params(data):
+    flat = {}
+    for k in ["standard", "scanned_encoder", "scanned_decoder"]:
+        flat.update(traverse_util.flatten_dict(unfreeze(data[k])))
+    return freeze(traverse_util.unflatten_dict(flat))
+
+
+class TrainState(struct.PyTreeNode):
+    step: int
+    params: core.FrozenDict[str, Any]
+    opt_state: optax.OptState
+    apply_fn: Callable = struct.field(pytree_node=False)
+    tx: optax.GradientTransformation = struct.field(pytree_node=False)
     dropout_rng: jnp.ndarray = None
     epoch: int = 0
     train_time: float = 0.0  # total time the model trained
     train_samples: int = 0  # number of samples seen

     def apply_gradients(self, *, grads, **kwargs):
-
-
-
-
-
+        grads = split_params(grads)
+        params = split_params(self.params)
+        opt_state = {}
+        # we loop over keys: "standard", "scanned_encoder", "scanned_decoder"
+        for k, param in params.items():
+            update_fn = self.tx[k].update
+            if "scanned" in k:
+                update_fn = jax.vmap(update_fn, in_axes=(0, 0, 0), out_axes=(0, 0))
+            updates, new_opt_state = update_fn(grads[k], self.opt_state[k], param)
+            params[k] = optax.apply_updates(param, updates)
+            opt_state[k] = new_opt_state
+        params = unsplit_params(params)
+
         return self.replace(
             step=self.step + 1,
-            params=
-            opt_state=
+            params=params,
+            opt_state=freeze(opt_state),
             **kwargs,
         )

     @classmethod
     def create(cls, *, apply_fn, params, tx, **kwargs):
-        opt_state =
+        opt_state = {}
+        for k, p in split_params(params).items():
+            init_fn = tx[k].init
+            if "scanned" in k:
+                init_fn = jax.vmap(init_fn)
+            opt_state[k] = init_fn(p)
         return cls(
             step=0,
             apply_fn=apply_fn,
             params=params,
             tx=tx,
-            opt_state=opt_state,
+            opt_state=freeze(opt_state),
             **kwargs,
         )

-    @staticmethod
-    def unscan(params):
-        params = unfreeze(params)
-        for l in ["encoder", "decoder"]:
-            params["model"][l]["layers"] = jax.tree_map(
-                lambda x: {f"{i}": x[i] for i in range(len(x))},
-                params["model"][l]["layers"],
-            )
-        params = freeze(params)
-        return params
-
-    @staticmethod
-    def rescan(params):
-        params = unfreeze(params)
-        for l in ["encoder", "decoder"]:
-            params["model"][l]["layers"] = jax.tree_map(
-                lambda x: jnp.stack([x[f"{i}"] for i in range(len(x))]),
-                params["model"][l]["layers"],
-                is_leaf=lambda x: "0" in x,
-            )
-        params = freeze(params)
-        return params
-

 def main():
     # See all possible arguments by passing the --help flag to this script.
@@ -792,23 +809,6 @@ def main():

     learning_rate_fn = create_learning_rate_fn()

-    # reshape params to split scanned layers for optimizers
-    if model.config.use_scan:
-        params_struct = unfreeze(model.params)
-        for l in ["encoder", "decoder"]:
-            params_struct["model"][l]["layers"] = jax.tree_map(
-                lambda x: {
-                    f"{i}": ShapeDtypeStruct(shape=x.shape[1:], dtype=x.dtype)
-                    for i in range(len(x))
-                },
-                params_struct["model"][l]["layers"],
-            )
-        params_struct = freeze(params_struct)
-
-    else:
-        params_struct = model.params
-    opt_param_spec = set_partitions(params_struct, False)
-
     # create adam optimizer
     if training_args.optim == "distributed_shampoo":
         # parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
@@ -820,7 +820,12 @@ def main():
            "sqrt_n": GraftingType.SQRT_N,
            "adagrad_normalized": GraftingType.ADAGRAD_NORMALIZED,
        }[training_args.graft_type]
-
+        statistics_partition_spec = (
+            PartitionSpec(None, training_args.shard_shampoo_across, None)
+            if training_args.shard_shampoo_across != "2d"
+            else PartitionSpec(None, "dp", "mp")
+        )
+        opt = distributed_shampoo(
            learning_rate_fn,
            block_size=training_args.block_size,
            beta1=training_args.beta1,
@@ -836,11 +841,7 @@ def main():
            graft_type=graft_type,
            nesterov=False,
            exponent_override=0,
-            statistics_partition_spec=
-                None, training_args.shard_shampoo_across, None
-            )
-            if training_args.shard_shampoo_across != "2d"
-            else PartitionSpec(None, "dp", "mp"),
+            statistics_partition_spec=statistics_partition_spec,
            preconditioner_partition_spec=PartitionSpec(
                training_args.shard_shampoo_across, None, None
            )
@@ -860,14 +861,18 @@ def main():
            best_effort_memory_usage_reduction=training_args.optim_quantized,
        )
        # get the real optimizer and helper functions
-        update_fn =
-
-        optimizer =
-        opt_fn =
-
-
-
-
+        update_fn = opt.update
+
+        optimizer = {}
+        opt_fn = {}
+        for k, p in split_params(model.params).items():
+            if "scanned" in k:
+                p = jax.eval_shape(lambda x: jax.tree_map(lambda y: y[0], x), p)
+            optimizer[k] = opt.init(p)
+            opt_fn[k] = NamedTuple("opt_fn", pspec_fn=Any, shape_and_dtype_fn=Any)(
+                optimizer[k].pspec_fn, optimizer[k].shape_and_dtype_fn
+            )
+            optimizer[k] = optax.GradientTransformation(optimizer[k].init_fn, update_fn)

    elif training_args.optim == "adam":
        optimizer = optax.adamw(
@@ -876,6 +881,8 @@ def main():
            b2=training_args.beta2,
            eps=training_args.adam_epsilon,
        )
+        optimizer = {k: optimizer for k in split_params(model.params)}
+
    elif training_args.optim == "adafactor":
        # We use the default parameters here to initialize adafactor,
        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
@@ -883,44 +890,66 @@ def main():
            learning_rate=learning_rate_fn,
            clipping_threshold=training_args.max_grad_norm,
        )
+        optimizer = {k: optimizer for k in split_params(model.params)}

    # get PartitionSpec for optimizer state
-    def get_opt_state_spec_and_shape(
+    def get_opt_state_spec_and_shape():
        # get opt_state shape without actual init
-        opt_state_shape =
+        opt_state_shape = {}
+        for k, p in split_params(model.params).items():
+            if "scanned" not in k:
+                opt_state_shape[k] = jax.eval_shape(optimizer[k].init, p)
+            else:
+                opt_state_shape[k] = jax.eval_shape(jax.vmap(optimizer[k].init), p)

-        if training_args.optim == "
+        if training_args.optim == "adafactor":
+            # factorized state must be replicated (rank different than params)
+            opt_state_spec = {k: None for k in split_params(model.params)}

-
+        elif training_args.optim in ["adam", "distributed_shampoo"]:
+
+            def _opt_state_spec_per_leaf(x, spec):
                if isinstance(x, FrozenDict):
                    # variables with same structure as params
-                    return
+                    return spec
                else:
                    # other variables such as count
                    return None

-
-
-
-
-
-
-
-
-
-
+            split_spec = split_params(set_partitions(model.params, False))
+            opt_state_spec = {}
+            for k, p in split_params(model.params).items():
+                if "scanned" in k:
+                    p = jax.eval_shape(lambda x: jax.tree_map(lambda y: y[0], x), p)
+                if training_args.optim == "adam":
+                    opt_state_spec[k] = jax.tree_map(
+                        _opt_state_spec_per_leaf,
+                        opt_state_shape[k],
+                        split_spec[k],
+                        # return None spec for empty elements
+                        is_leaf=lambda x: isinstance(x, (FrozenDict, optax.EmptyState)),
+                    )
+                elif training_args.optim == "distributed_shampoo":
+                    opt_state_spec[k] = opt_fn[k].pspec_fn(
+                        p,
+                        split_spec[k],
+                        statistics_partition_spec,
+                    )
+                # add dimension for scanned params
+                if "scanned" in k:
+                    opt_state_spec[k] = jax.tree_map(
+                        lambda x: PartitionSpec(*(None,) + x)
+                        if x is not None
+                        else None,
+                        opt_state_spec[k],
+                        is_leaf=lambda x: isinstance(x, PartitionSpec),
+                    )

-        elif training_args.optim == "distributed_shampoo":
-            opt_state_spec = opt_fn.pspec_fn(
-                params=params_struct,
-                params_partition_spec=param_spec,
-                partition_spec_for_statistics=PartitionSpec(None, "dp", None),
-            )
        else:
            raise NotImplementedError
-        return opt_state_spec, opt_state_shape
+        return freeze(opt_state_spec), freeze(opt_state_shape)

-    opt_state_spec, opt_state_shape = get_opt_state_spec_and_shape(
+    opt_state_spec, opt_state_shape = get_opt_state_spec_and_shape()

    # create a mesh
    mesh_shape = (training_args.dp_devices, training_args.mp_devices)
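The heart of the change is in TrainState.create and TrainState.apply_gradients above: scanned layer parameters keep their leading layer axis, and the optimizer's init / update are wrapped in jax.vmap over that axis so each layer gets its own optimizer state without ever unscanning the tree (the old unscan / rescan helpers are removed). Below is a self-contained sketch of that pattern on a toy parameter tree; the keys "standard" / "scanned" and the shapes are illustrative, not taken from the repo.

import jax
import jax.numpy as jnp
import optax

tx = optax.adamw(learning_rate=1e-3)

# "standard" params keep their normal shape; "scanned" params carry a leading
# num_layers axis (here 4 layers of an (8, 8) kernel).
params = {
    "standard": {"embed": jnp.ones((16, 8))},
    "scanned": {"kernel": jnp.ones((4, 8, 8))},
}
grads = jax.tree_util.tree_map(jnp.ones_like, params)

# init: vmap over the layer axis for scanned params only
opt_state = {
    "standard": tx.init(params["standard"]),
    "scanned": jax.vmap(tx.init)(params["scanned"]),
}

# update: vmap update over (grads, opt_state, params) for scanned params
new_params, new_opt_state = {}, {}
for k in params:
    update_fn = tx.update
    if k == "scanned":
        update_fn = jax.vmap(update_fn, in_axes=(0, 0, 0), out_axes=(0, 0))
    updates, new_opt_state[k] = update_fn(grads[k], opt_state[k], params[k])
    new_params[k] = optax.apply_updates(params[k], updates)

# every leaf of the scanned optimizer state carries the extra layer dimension
print(jax.tree_util.tree_map(lambda x: x.shape, new_opt_state["scanned"]))

Compared with the removed unscan / rescan approach, this keeps one stacked array per layer group instead of materializing per-layer sub-dicts, which also lines up with the leading None dimension added to the scanned PartitionSpecs in partitions.py.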