File size: 10,402 Bytes
5a63fc6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 |
import os
import time
from typing import Any, Mapping, Text, Tuple, Union, NamedTuple
from functools import partial
import re
import dataclasses
import random
from ml_collections.config_dict import config_dict
from ml_collections import ConfigDict
import jax
import jax.numpy as jnp
import numpy as np
from absl import logging
import optax
from EasyLM.jax_utils import float_to_dtype
class OptimizerFactory(object):
    """ Top-level factory that selects and builds one of the optax optimizers.

    The class is used purely as a namespace for its static/class methods;
    instantiating it raises ``NotImplementedError``.
    """

    def __init__(self):
        raise NotImplementedError

    @staticmethod
    def get_default_config(updates=None):
        """ Build the default config, optionally overlaid with ``updates``.

        Args:
            updates: optional mapping/ConfigDict of overrides. It is wrapped
                in a ``ConfigDict``, its references are resolved, and the
                result is merged into the defaults.

        Returns:
            A ``ConfigDict`` holding ``accumulate_gradient_steps``, ``type``
            and one sub-config per supported optimizer.
        """
        defaults = ConfigDict()
        defaults.accumulate_gradient_steps = 1
        defaults.type = 'adamw'
        defaults.palm_optimizer = PalmOptimizerFactory.get_default_config()
        defaults.adamw_optimizer = AdamWOptimizerFactory.get_default_config()
        defaults.lion_optimizer = LionOptimizerFactory.get_default_config()

        if updates is None:
            return defaults

        overrides = ConfigDict(updates).copy_and_resolve_references()
        defaults.update(overrides)
        return defaults

    @classmethod
    def get_optimizer(cls, config, weight_decay_mask=None):
        """ Construct the optimizer named by ``config.type``.

        Args:
            config: overrides applied on top of ``get_default_config()``.
            weight_decay_mask: forwarded unchanged to the selected factory.

        Returns:
            ``(optimizer, optimizer_info)`` as produced by the selected
            factory. When ``accumulate_gradient_steps`` is greater than 1 the
            optimizer is wrapped in ``optax.MultiSteps``.

        Raises:
            ValueError: if ``config.type`` is not 'palm', 'adamw' or 'lion'.
        """
        resolved = cls.get_default_config(config)

        # Optimizer type -> (factory class, name of its sub-config section).
        dispatch = {
            'palm': (PalmOptimizerFactory, 'palm_optimizer'),
            'adamw': (AdamWOptimizerFactory, 'adamw_optimizer'),
            'lion': (LionOptimizerFactory, 'lion_optimizer'),
        }
        if resolved.type not in dispatch:
            raise ValueError(f'Unknown optimizer type: {resolved.type}')

        factory, section = dispatch[resolved.type]
        optimizer, optimizer_info = factory.get_optimizer(
            getattr(resolved, section), weight_decay_mask
        )

        accumulation = resolved.accumulate_gradient_steps
        if accumulation > 1:
            optimizer = optax.MultiSteps(optimizer, accumulation)

        return optimizer, optimizer_info
class PalmOptimizerFactory(object):
    """ PaLM optimizer factory. This optimizer implements the optimizer
    described in the PaLM paper: https://arxiv.org/abs/2204.02311

    The class is used purely as a namespace for its static/class methods;
    instantiating it raises ``NotImplementedError``.
    """

    def __init__(self):
        raise NotImplementedError

    @staticmethod
    def get_default_config(updates=None):
        """ Build the default PaLM optimizer config.

        Args:
            updates: optional mapping/ConfigDict of overrides merged into the
                defaults after its references are resolved.

        Returns:
            A ``ConfigDict`` with ``lr``, ``lr_warmup_steps``, ``b1``, ``b2``,
            ``clip_gradient``, ``weight_decay`` and ``bf16_momentum``.
        """
        config = ConfigDict()
        config.lr = 0.01
        config.lr_warmup_steps = 10000
        config.b1 = 0.9
        config.b2 = 0.99
        config.clip_gradient = 1.0
        config.weight_decay = 1e-4
        config.bf16_momentum = False

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    @classmethod
    def get_optimizer(cls, config, weight_decay_mask=None):
        """ Build the PaLM-style optimizer chain.

        The chain is: global-norm gradient clipping, unfactored
        ``optax.adafactor`` with parameter scaling, then scheduled weight
        decay.

        Args:
            config: overrides applied on top of ``get_default_config()``.
            weight_decay_mask: optional mask passed to the scheduled weight
                decay transformation.

        Returns:
            ``(optimizer, optimizer_info)`` where ``optimizer_info`` exposes
            ``learning_rate_schedule`` and ``weight_decay_schedule``.
        """
        config = cls.get_default_config(config)

        # Floor for the step used in the inverse-sqrt schedule. Clamped to at
        # least 1 so that lr_warmup_steps <= 0 cannot produce a division by
        # zero (infinite learning rate and weight decay) at step 0. For any
        # lr_warmup_steps >= 1 this is identical to using the value directly.
        warmup_floor = max(config.lr_warmup_steps, 1)

        def learning_rate_schedule(step):
            # Constant at multiplier / sqrt(warmup_floor) until `step` reaches
            # the floor, then decays as multiplier / sqrt(step). With the
            # defaults (lr=0.01, lr_warmup_steps=10000) the plateau is 0.01.
            multiplier = config.lr / 0.01
            return multiplier / jnp.sqrt(jnp.maximum(step, warmup_floor))

        def weight_decay_schedule(step):
            # Weight decay scales with the square of the learning rate. The
            # value is negative because it is added to updates that have
            # already been scaled by the (negated) learning rate.
            multiplier = config.weight_decay / 1e-4
            return -multiplier * jnp.square(learning_rate_schedule(step))

        optimizer_info = dict(
            learning_rate_schedule=learning_rate_schedule,
            weight_decay_schedule=weight_decay_schedule,
        )

        optimizer = optax.chain(
            optax.clip_by_global_norm(config.clip_gradient),
            # NOTE(review): optax documents adafactor's `decay_rate` as the
            # exponent of its second-moment decay schedule rather than an
            # Adam-style beta2 -- confirm that passing config.b2 is intended.
            optax.adafactor(
                learning_rate=learning_rate_schedule,
                multiply_by_parameter_scale=True,
                momentum=config.b1,
                decay_rate=config.b2,
                factored=False,
                clipping_threshold=None,
                dtype_momentum=jnp.bfloat16 if config.bf16_momentum else jnp.float32,
            ),
            optax_add_scheduled_weight_decay(
                weight_decay_schedule, weight_decay_mask
            )
        )
        return optimizer, optimizer_info
class AdamWOptimizerFactory(object):
    """ Factory for an AdamW optimizer driven by a warmup + cosine decay
    learning rate schedule.

    The class is used purely as a namespace for its static/class methods;
    instantiating it raises ``NotImplementedError``.
    """

    def __init__(self):
        raise NotImplementedError

    @staticmethod
    def get_default_config(updates=None):
        """ Build the default AdamW config, optionally overlaid with
        ``updates`` (a mapping/ConfigDict whose references are resolved
        before merging).
        """
        defaults = ConfigDict()
        defaults.init_lr = 0.0
        defaults.end_lr = 0.001
        defaults.lr = 0.01
        defaults.lr_warmup_steps = 2000
        defaults.lr_decay_steps = 500000
        defaults.b1 = 0.9
        defaults.b2 = 0.95
        defaults.clip_gradient = 1.0
        defaults.weight_decay = 1e-4
        defaults.bf16_momentum = False
        defaults.multiply_by_parameter_scale = False

        if updates is None:
            return defaults

        overrides = ConfigDict(updates).copy_and_resolve_references()
        defaults.update(overrides)
        return defaults

    @classmethod
    def get_optimizer(cls, config, weight_decay_mask=None):
        """ Build the optimizer chain.

        Gradients are always clipped by global norm first. When
        ``multiply_by_parameter_scale`` is set, the chain continues with
        unfactored ``optax.adafactor`` followed by scheduled weight decay;
        otherwise it continues with ``optax.adamw``.

        Args:
            config: overrides applied on top of ``get_default_config()``.
            weight_decay_mask: optional mask restricting weight decay.

        Returns:
            ``(optimizer, optimizer_info)`` where ``optimizer_info`` exposes
            ``learning_rate_schedule``.
        """
        cfg = cls.get_default_config(config)

        learning_rate_schedule = optax.warmup_cosine_decay_schedule(
            init_value=cfg.init_lr,
            peak_value=cfg.lr,
            warmup_steps=cfg.lr_warmup_steps,
            decay_steps=cfg.lr_decay_steps,
            end_value=cfg.end_lr,
        )
        optimizer_info = dict(learning_rate_schedule=learning_rate_schedule)

        if cfg.bf16_momentum:
            momentum_dtype = jnp.bfloat16
        else:
            momentum_dtype = jnp.float32

        # Every variant starts with global-norm gradient clipping.
        stages = [optax.clip_by_global_norm(cfg.clip_gradient)]

        if cfg.multiply_by_parameter_scale:

            def scheduled_decay(step):
                # Negative because it is added to updates that adafactor has
                # already scaled by the (negated) learning rate.
                return -learning_rate_schedule(step) * cfg.weight_decay

            stages.append(
                optax.adafactor(
                    learning_rate=learning_rate_schedule,
                    multiply_by_parameter_scale=True,
                    momentum=cfg.b1,
                    decay_rate=cfg.b2,
                    factored=False,
                    clipping_threshold=None,
                    dtype_momentum=momentum_dtype,
                )
            )
            stages.append(
                optax_add_scheduled_weight_decay(
                    scheduled_decay, weight_decay_mask
                )
            )
        else:
            stages.append(
                optax.adamw(
                    learning_rate=learning_rate_schedule,
                    weight_decay=cfg.weight_decay,
                    b1=cfg.b1,
                    b2=cfg.b2,
                    mask=weight_decay_mask,
                    mu_dtype=momentum_dtype,
                )
            )

        optimizer = optax.chain(*stages)
        return optimizer, optimizer_info
class LionOptimizerFactory(object):
    """ Factory for a Lion optimizer with a selectable learning rate
    schedule (warmup + cosine decay, warmup + constant, or exponential decay).

    The class is used purely as a namespace for its static/class methods;
    instantiating it raises ``NotImplementedError``.
    """

    def __init__(self):
        raise NotImplementedError

    @staticmethod
    def get_default_config(updates=None):
        """ Build the default Lion config, optionally overlaid with
        ``updates`` (a mapping/ConfigDict whose references are resolved
        before merging).
        """
        defaults = ConfigDict()
        defaults.init_lr = 0.0
        defaults.end_lr = 0.0001
        defaults.lr = 0.001
        defaults.lr_warmup_steps = 2000
        defaults.lr_decay_steps = 500000
        defaults.b1 = 0.9
        defaults.b2 = 0.98
        defaults.clip_gradient = 1.0
        defaults.weight_decay = 1e-3
        defaults.bf16_momentum = False
        defaults.lr_schedule_type = "warmup_cosine_decay_schedule"
        defaults.lr_decay_rate = 0.98

        if updates is None:
            return defaults

        overrides = ConfigDict(updates).copy_and_resolve_references()
        defaults.update(overrides)
        return defaults

    @classmethod
    def get_optimizer(cls, config, weight_decay_mask=None):
        """ Build the Lion optimizer chain (global-norm clipping followed by
        ``optax.lion``).

        Args:
            config: overrides applied on top of ``get_default_config()``.
            weight_decay_mask: optional mask restricting weight decay.

        Returns:
            ``(optimizer, optimizer_info)`` where ``optimizer_info`` exposes
            ``learning_rate_schedule``.

        Raises:
            ValueError: if ``config.lr_schedule_type`` is not one of the
                three supported schedule names.
        """
        cfg = cls.get_default_config(config)
        schedule_type = cfg.lr_schedule_type

        if schedule_type == "warmup_cosine_decay_schedule":
            learning_rate_schedule = optax.warmup_cosine_decay_schedule(
                init_value=cfg.init_lr,
                peak_value=cfg.lr,
                warmup_steps=cfg.lr_warmup_steps,
                decay_steps=cfg.lr_decay_steps,
                end_value=cfg.end_lr,
            )
        elif schedule_type == "warmup_constant":
            # Linear ramp from init_lr to lr, then hold lr; the switch
            # between the two schedules happens at lr_warmup_steps.
            warmup = optax.linear_schedule(
                init_value=cfg.init_lr,
                end_value=cfg.lr,
                transition_steps=cfg.lr_warmup_steps,
            )
            plateau = optax.constant_schedule(cfg.lr)
            learning_rate_schedule = optax.join_schedules(
                [warmup, plateau],
                [cfg.lr_warmup_steps],
            )
        elif schedule_type == "exponential_decay":
            learning_rate_schedule = optax.exponential_decay(
                init_value=cfg.lr,
                transition_steps=cfg.lr_decay_steps,
                decay_rate=cfg.lr_decay_rate,
                transition_begin=0,
                staircase=False,
                end_value=cfg.end_lr,
            )
        else:
            raise ValueError('config.lr_schedule_type must be "warmup_cosine_decay_schedule", "warmup_constant", or "exponential_decay"')

        optimizer_info = dict(learning_rate_schedule=learning_rate_schedule)

        if cfg.bf16_momentum:
            momentum_dtype = jnp.bfloat16
        else:
            momentum_dtype = jnp.float32

        clip = optax.clip_by_global_norm(cfg.clip_gradient)
        lion = optax.lion(
            learning_rate=learning_rate_schedule,
            weight_decay=cfg.weight_decay,
            b1=cfg.b1,
            b2=cfg.b2,
            mask=weight_decay_mask,
            mu_dtype=momentum_dtype,
        )
        optimizer = optax.chain(clip, lion)
        return optimizer, optimizer_info
class OptaxScheduledWeightDecayState(NamedTuple):
    """ State carried by the scheduled weight decay transformation. """

    # Step counter; it is the value handed to the weight decay schedule.
    count: jax.Array
def optax_add_scheduled_weight_decay(schedule_fn, mask=None):
    """ Create a gradient transformation that adds scheduled weight decay.

    On every update, ``schedule_fn(count) * param`` is added to each update
    leaf, where ``count`` is the number of updates applied so far.

    Args:
        schedule_fn: callable mapping the step count to a weight decay
            coefficient.
        mask: optional mask; when given, the transformation is wrapped in
            ``optax.masked`` with it.

    Returns:
        An ``optax.GradientTransformation`` (wrapped by ``optax.masked`` when
        ``mask`` is not None).
    """

    def init_fn(params):
        # The state is only a step counter, so the parameters are unused.
        del params
        return OptaxScheduledWeightDecayState(count=jnp.zeros([], jnp.int32))

    def update_fn(updates, state, params):
        if params is None:
            raise ValueError('Params cannot be None for weight decay!')

        coefficient = schedule_fn(state.count)

        def add_decay(update, param):
            return update + coefficient * param

        decayed = jax.tree_util.tree_map(add_decay, updates, params)
        next_state = OptaxScheduledWeightDecayState(
            count=optax.safe_int32_increment(state.count)
        )
        return decayed, next_state

    transformation = optax.GradientTransformation(init_fn, update_fn)
    if mask is not None:
        transformation = optax.masked(transformation, mask)
    return transformation
|