Ahma-7B / EasyLM /optimizers.py
aapot
Add training codes
a85f909
raw history blame
No virus
10.4 kB
import os
import time
from typing import Any, Mapping, Text, Tuple, Union, NamedTuple
from functools import partial
import re
import dataclasses
import random
from ml_collections.config_dict import config_dict
from ml_collections import ConfigDict
import jax
import jax.numpy as jnp
import numpy as np
from absl import logging
import optax
from EasyLM.jax_utils import float_to_dtype
class OptimizerFactory(object):
    """ Entry point that builds whichever optax optimizer the config selects. """

    def __init__(self):
        # Used purely as a namespace for the static/class methods below.
        raise NotImplementedError

    @staticmethod
    def get_default_config(updates=None):
        """ Return the default config, with `updates` merged on top if given.

        The defaults embed the default config of every supported optimizer
        under `palm_optimizer`, `adamw_optimizer` and `lion_optimizer`.
        """
        defaults = ConfigDict()
        defaults.accumulate_gradient_steps = 1
        defaults.type = 'adamw'
        defaults.palm_optimizer = PalmOptimizerFactory.get_default_config()
        defaults.adamw_optimizer = AdamWOptimizerFactory.get_default_config()
        defaults.lion_optimizer = LionOptimizerFactory.get_default_config()

        if updates is None:
            return defaults
        overrides = ConfigDict(updates).copy_and_resolve_references()
        defaults.update(overrides)
        return defaults

    @classmethod
    def get_optimizer(cls, config, weight_decay_mask=None):
        """ Build the optimizer selected by `config.type`.

        Args:
            config: overrides applied on top of `get_default_config()`.
            weight_decay_mask: forwarded unchanged to the selected factory.

        Returns:
            An `(optimizer, optimizer_info)` pair as produced by the selected
            factory; the optimizer is wrapped in `optax.MultiSteps` when
            `config.accumulate_gradient_steps > 1`.

        Raises:
            ValueError: if `config.type` is not 'palm', 'adamw' or 'lion'.
        """
        config = cls.get_default_config(config)

        # Optimizer type -> (factory class, name of its sub-config field).
        registry = {
            'palm': (PalmOptimizerFactory, 'palm_optimizer'),
            'adamw': (AdamWOptimizerFactory, 'adamw_optimizer'),
            'lion': (LionOptimizerFactory, 'lion_optimizer'),
        }
        entry = registry.get(config.type)
        if entry is None:
            raise ValueError(f'Unknown optimizer type: {config.type}')

        factory, sub_config_name = entry
        optimizer, optimizer_info = factory.get_optimizer(
            getattr(config, sub_config_name), weight_decay_mask
        )

        accumulation_steps = config.accumulate_gradient_steps
        if accumulation_steps > 1:
            optimizer = optax.MultiSteps(optimizer, accumulation_steps)
        return optimizer, optimizer_info
class PalmOptimizerFactory(object):
    """ PaLM optimizer factory. This optimizer implements the optimizer
        described in the PaLM paper: https://arxiv.org/abs/2204.02311
    """

    def __init__(self):
        # Used purely as a namespace for the static/class methods below.
        raise NotImplementedError

    @staticmethod
    def get_default_config(updates=None):
        """ Return the default config, with `updates` merged on top if given. """
        config = ConfigDict()
        config.lr = 0.01
        config.lr_warmup_steps = 10000
        config.b1 = 0.9
        config.b2 = 0.99
        config.clip_gradient = 1.0
        config.weight_decay = 1e-4
        config.bf16_momentum = False

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    @classmethod
    def get_optimizer(cls, config, weight_decay_mask=None):
        """ Build the optimizer chain: clipping, adafactor, weight decay.

        Args:
            config: overrides applied on top of `get_default_config()`.
            weight_decay_mask: optional mask handed to
                `optax_add_scheduled_weight_decay`.

        Returns:
            `(optimizer, optimizer_info)` where `optimizer_info` holds the
            `learning_rate_schedule` and `weight_decay_schedule` callables.
        """
        config = cls.get_default_config(config)

        # Floor for the step used in the inverse-sqrt schedule. Clamped to at
        # least 1 so that `lr_warmup_steps=0` does not produce 1/sqrt(0) = inf
        # at step 0 (which would also make the weight decay infinite).
        step_floor = max(config.lr_warmup_steps, 1)

        def learning_rate_schedule(step):
            # (lr / 0.01) / sqrt(max(step, floor)): constant until the floor
            # is reached, then inverse-sqrt decay.
            multiplier = config.lr / 0.01
            return multiplier / jnp.sqrt(jnp.maximum(step, step_floor))

        def weight_decay_schedule(step):
            # Scaled square of the learning rate. The value is negative;
            # optax_add_scheduled_weight_decay adds `weight_decay * p` to
            # the updates.
            multiplier = config.weight_decay / 1e-4
            return -multiplier * jnp.square(learning_rate_schedule(step))

        optimizer_info = dict(
            learning_rate_schedule=learning_rate_schedule,
            weight_decay_schedule=weight_decay_schedule,
        )

        optimizer = optax.chain(
            optax.clip_by_global_norm(config.clip_gradient),
            optax.adafactor(
                learning_rate=learning_rate_schedule,
                multiply_by_parameter_scale=True,
                momentum=config.b1,
                decay_rate=config.b2,
                factored=False,
                clipping_threshold=None,
                dtype_momentum=jnp.bfloat16 if config.bf16_momentum else jnp.float32,
            ),
            optax_add_scheduled_weight_decay(
                weight_decay_schedule, weight_decay_mask
            )
        )
        return optimizer, optimizer_info
class AdamWOptimizerFactory(object):
    """ Factory for AdamW driven by a warmup + cosine decay schedule. """

    def __init__(self):
        # Used purely as a namespace for the static/class methods below.
        raise NotImplementedError

    @staticmethod
    def get_default_config(updates=None):
        """ Return the default config, with `updates` merged on top if given. """
        defaults = ConfigDict()
        defaults.init_lr = 0.0
        defaults.end_lr = 0.001
        defaults.lr = 0.01
        defaults.lr_warmup_steps = 2000
        defaults.lr_decay_steps = 500000
        defaults.b1 = 0.9
        defaults.b2 = 0.95
        defaults.clip_gradient = 1.0
        defaults.weight_decay = 1e-4
        defaults.bf16_momentum = False
        defaults.multiply_by_parameter_scale = False

        if updates is None:
            return defaults
        overrides = ConfigDict(updates).copy_and_resolve_references()
        defaults.update(overrides)
        return defaults

    @classmethod
    def get_optimizer(cls, config, weight_decay_mask=None):
        """ Build gradient clipping followed by the configured update rule.

        When `config.multiply_by_parameter_scale` is set, the update rule is
        unfactored adafactor with parameter scaling plus a scheduled weight
        decay term; otherwise it is `optax.adamw`.

        Args:
            config: overrides applied on top of `get_default_config()`.
            weight_decay_mask: optional mask restricting weight decay.

        Returns:
            `(optimizer, optimizer_info)` where `optimizer_info` holds the
            `learning_rate_schedule` callable.
        """
        config = cls.get_default_config(config)

        lr_fn = optax.warmup_cosine_decay_schedule(
            init_value=config.init_lr,
            peak_value=config.lr,
            warmup_steps=config.lr_warmup_steps,
            decay_steps=config.lr_decay_steps,
            end_value=config.end_lr,
        )
        optimizer_info = dict(learning_rate_schedule=lr_fn)

        # Both branches store momentum in the same dtype.
        momentum_dtype = jnp.float32
        if config.bf16_momentum:
            momentum_dtype = jnp.bfloat16

        clip = optax.clip_by_global_norm(config.clip_gradient)

        if not config.multiply_by_parameter_scale:
            update_rule = [
                optax.adamw(
                    learning_rate=lr_fn,
                    weight_decay=config.weight_decay,
                    b1=config.b1,
                    b2=config.b2,
                    mask=weight_decay_mask,
                    mu_dtype=momentum_dtype,
                ),
            ]
        else:
            def scheduled_decay(step):
                # Negative learning rate times the decay coefficient; this is
                # the factor multiplied with the parameters and added to the
                # updates by optax_add_scheduled_weight_decay.
                return -lr_fn(step) * config.weight_decay

            update_rule = [
                optax.adafactor(
                    learning_rate=lr_fn,
                    multiply_by_parameter_scale=True,
                    momentum=config.b1,
                    decay_rate=config.b2,
                    factored=False,
                    clipping_threshold=None,
                    dtype_momentum=momentum_dtype,
                ),
                optax_add_scheduled_weight_decay(
                    scheduled_decay, weight_decay_mask
                ),
            ]

        optimizer = optax.chain(clip, *update_rule)
        return optimizer, optimizer_info
class LionOptimizerFactory(object):
    """ Factory for the Lion optimizer with a selectable learning rate schedule. """

    def __init__(self):
        # Used purely as a namespace for the static/class methods below.
        raise NotImplementedError

    @staticmethod
    def get_default_config(updates=None):
        """ Return the default config, with `updates` merged on top if given. """
        defaults = ConfigDict()
        defaults.init_lr = 0.0
        defaults.end_lr = 0.0001
        defaults.lr = 0.001
        defaults.lr_warmup_steps = 2000
        defaults.lr_decay_steps = 500000
        defaults.b1 = 0.9
        defaults.b2 = 0.98
        defaults.clip_gradient = 1.0
        defaults.weight_decay = 1e-3
        defaults.bf16_momentum = False
        defaults.lr_schedule_type = "warmup_cosine_decay_schedule"
        defaults.lr_decay_rate = 0.98

        if updates is None:
            return defaults
        overrides = ConfigDict(updates).copy_and_resolve_references()
        defaults.update(overrides)
        return defaults

    @classmethod
    def get_optimizer(cls, config, weight_decay_mask=None):
        """ Build gradient clipping followed by `optax.lion`.

        `config.lr_schedule_type` selects the learning rate schedule:
        "warmup_cosine_decay_schedule", "warmup_constant" (linear warmup
        joined with a constant schedule at `lr_warmup_steps`) or
        "exponential_decay".

        Args:
            config: overrides applied on top of `get_default_config()`.
            weight_decay_mask: optional mask passed to `optax.lion`.

        Returns:
            `(optimizer, optimizer_info)` where `optimizer_info` holds the
            `learning_rate_schedule` callable.

        Raises:
            ValueError: if `config.lr_schedule_type` is not one of the three
                supported names.
        """
        config = cls.get_default_config(config)
        schedule_kind = config.lr_schedule_type

        if schedule_kind == "warmup_cosine_decay_schedule":
            lr_fn = optax.warmup_cosine_decay_schedule(
                init_value=config.init_lr,
                peak_value=config.lr,
                warmup_steps=config.lr_warmup_steps,
                decay_steps=config.lr_decay_steps,
                end_value=config.end_lr,
            )
        elif schedule_kind == "warmup_constant":
            warmup = optax.linear_schedule(
                init_value=config.init_lr,
                end_value=config.lr,
                transition_steps=config.lr_warmup_steps,
            )
            plateau = optax.constant_schedule(config.lr)
            # Switch from the warmup ramp to the constant value at the
            # warmup boundary.
            lr_fn = optax.join_schedules(
                [warmup, plateau],
                [config.lr_warmup_steps],
            )
        elif schedule_kind == "exponential_decay":
            lr_fn = optax.exponential_decay(
                init_value=config.lr,
                transition_steps=config.lr_decay_steps,
                decay_rate=config.lr_decay_rate,
                transition_begin=0,
                staircase=False,
                end_value=config.end_lr,
            )
        else:
            raise ValueError('config.lr_schedule_type must be "warmup_cosine_decay_schedule", "warmup_constant", or "exponential_decay"')

        momentum_dtype = jnp.float32
        if config.bf16_momentum:
            momentum_dtype = jnp.bfloat16

        clip = optax.clip_by_global_norm(config.clip_gradient)
        lion = optax.lion(
            learning_rate=lr_fn,
            weight_decay=config.weight_decay,
            b1=config.b1,
            b2=config.b2,
            mask=weight_decay_mask,
            mu_dtype=momentum_dtype,
        )
        optimizer = optax.chain(clip, lion)
        return optimizer, dict(learning_rate_schedule=lr_fn)
class OptaxScheduledWeightDecayState(NamedTuple):
    """ State carried by the scheduled weight decay transformation. """

    # Step counter; fed to the weight decay schedule and incremented on
    # every update by optax_add_scheduled_weight_decay.
    count: jax.Array
def optax_add_scheduled_weight_decay(schedule_fn, mask=None):
    """ Gradient transformation adding `schedule_fn(step) * param` to updates.

    Args:
        schedule_fn: callable mapping the step counter to the weight decay
            coefficient.
        mask: optional mask; when given, the transformation is wrapped in
            `optax.masked` with it.

    Returns:
        An `optax.GradientTransformation`, wrapped in `optax.masked` when
        `mask` is provided.
    """

    def init_fn(params):
        # The state is only a step counter, so the parameters are not needed.
        del params
        return OptaxScheduledWeightDecayState(count=jnp.zeros([], jnp.int32))

    def update_fn(updates, state, params):
        if params is None:
            raise ValueError('Params cannot be None for weight decay!')

        coefficient = schedule_fn(state.count)

        def add_decay(update, param):
            return update + coefficient * param

        decayed_updates = jax.tree_util.tree_map(add_decay, updates, params)
        next_state = OptaxScheduledWeightDecayState(
            count=optax.safe_int32_increment(state.count)
        )
        return decayed_updates, next_state

    transformation = optax.GradientTransformation(init_fn, update_fn)
    if mask is None:
        return transformation
    return optax.masked(transformation, mask)