import torch

from accelerate.logging import get_logger


logger = get_logger(__name__)


def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
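    """Build the optimizer selected by ``args.optimizer`` for ``params_to_optimize``.

    Supported choices are "adam", "adamw" (optionally 8-bit via bitsandbytes when
    ``args.use_8bit_adam`` is set) and "prodigy"; unsupported choices fall back to AdamW.
    """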
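    # Hedged sketch: `use_deepspeed` is otherwise unused in this function. The assumption here
    # is that under DeepSpeed the real optimizer is defined in the DeepSpeed config, so
    # accelerate's `DummyOptim` placeholder is returned instead of constructing one locally.
    if use_deepspeed:
        from accelerate.utils import DummyOptim

        return DummyOptim(
            params_to_optimize,
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            weight_decay=args.adam_weight_decay,
        )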
    # Fall back to AdamW when an unsupported optimizer is requested.
    supported_optimizers = ["adam", "adamw", "prodigy"]
    if args.optimizer.lower() not in supported_optimizers:
        logger.warning(
            f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include "
            f"{supported_optimizers}. Defaulting to AdamW."
        )
        args.optimizer = "adamw"

    # 8-bit optimizer states (bitsandbytes) are only available for Adam/AdamW, so warn when
    # the flag is set for any other optimizer.
    if args.use_8bit_adam and args.optimizer.lower() not in ["adam", "adamw"]:
        logger.warning(
            "use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. Optimizer was "
            f"set to {args.optimizer.lower()}."
        )

    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

    if args.optimizer.lower() == "adamw":
        optimizer_class = bnb.optim.AdamW8bit if args.use_8bit_adam else torch.optim.AdamW

        optimizer = optimizer_class(
            params_to_optimize,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            weight_decay=args.adam_weight_decay,
        )
    elif args.optimizer.lower() == "adam":
        optimizer_class = bnb.optim.Adam8bit if args.use_8bit_adam else torch.optim.Adam

        optimizer = optimizer_class(
            params_to_optimize,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            weight_decay=args.adam_weight_decay,
        )
    elif args.optimizer.lower() == "prodigy":
        try:
            import prodigyopt
        except ImportError:
            raise ImportError(
                "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`."
            )

        optimizer_class = prodigyopt.Prodigy

        # Prodigy estimates the step size itself; the learning rate acts as a multiplier on that
        # estimate and is typically left around 1.0.
        if args.learning_rate <= 0.1:
            logger.warning(
                "Learning rate is too low. When using Prodigy, it's generally better to set the learning rate around 1.0."
            )

        optimizer = optimizer_class(
            params_to_optimize,
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            beta3=args.prodigy_beta3,
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
            decouple=args.prodigy_decouple,
            use_bias_correction=args.prodigy_use_bias_correction,
            safeguard_warmup=args.prodigy_safeguard_warmup,
        )

    return optimizer
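

if __name__ == "__main__":
    # Illustrative usage sketch only, not part of the original training flow. The attribute
    # names below are exactly the ones this function reads from `args`; the values are
    # placeholder assumptions, and in the real script they would come from argparse.
    from types import SimpleNamespace

    example_args = SimpleNamespace(
        optimizer="adamw",
        use_8bit_adam=False,
        learning_rate=1e-4,
        adam_beta1=0.9,
        adam_beta2=0.999,
        adam_epsilon=1e-8,
        adam_weight_decay=1e-2,
        prodigy_beta3=None,
        prodigy_decouple=True,
        prodigy_use_bias_correction=True,
        prodigy_safeguard_warmup=True,
    )
    example_model = torch.nn.Linear(8, 8)
    print(get_optimizer(example_args, example_model.parameters()))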