# egrpo / fastvideo/utils/optimizer.py
# Uploaded by studyOverflow with the upload-large-folder tool (commit b171568, verified).
# This code file is from https://github.com/hao-ai-lab/FastVideo, which is licensed under the Apache License 2.0.
import torch
from accelerate.logging import get_logger
# Module-level logger (from accelerate's get_logger) used for the warnings
# emitted by get_optimizer.
logger = get_logger(__name__)
def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
    """Build the training optimizer selected by ``args.optimizer``.

    Supported names (matched case-insensitively) are ``"adam"``, ``"adamw"``
    and ``"prodigy"``. Any other value logs a warning, overwrites
    ``args.optimizer`` with ``"adamw"`` and falls back to AdamW.

    Args:
        args: Namespace-like config. Always reads ``optimizer``,
            ``use_8bit_adam``, ``adam_beta1``, ``adam_beta2``,
            ``adam_epsilon`` and ``adam_weight_decay``. The Prodigy branch
            additionally reads ``learning_rate``, ``prodigy_beta3``,
            ``prodigy_decouple``, ``prodigy_use_bias_correction`` and
            ``prodigy_safeguard_warmup``.
        params_to_optimize: Parameters (or parameter-group dicts) passed
            unchanged to the optimizer constructor.
        use_deepspeed: Accepted for interface compatibility; not used here.

    Returns:
        The constructed optimizer instance.

    Raises:
        ImportError: If 8-bit Adam/AdamW is requested and ``bitsandbytes``
            is not installed, or Prodigy is selected and ``prodigyopt`` is
            not installed.
    """
    supported_optimizers = ["adam", "adamw", "prodigy"]

    # Normalise once so membership and dispatch agree on casing; a non-string
    # value maps to None and is treated as unsupported.
    optimizer_name = (args.optimizer.lower()
                      if isinstance(args.optimizer, str) else None)
    if optimizer_name not in supported_optimizers:
        logger.warning(
            f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include {supported_optimizers}. Defaulting to AdamW"
        )
        args.optimizer = "adamw"
        optimizer_name = "adamw"

    is_adam_family = optimizer_name in ("adam", "adamw")

    # 8-bit mode only applies to Adam/AdamW: warn when it is being ignored.
    if args.use_8bit_adam and not is_adam_family:
        logger.warning(
            f"use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. Optimizer was "
            f"set to {optimizer_name}")

    if is_adam_family:
        if args.use_8bit_adam:
            # Import lazily so bitsandbytes is only required when it is used.
            try:
                import bitsandbytes as bnb
            except ImportError as err:
                raise ImportError(
                    "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
                ) from err
            optimizer_class = (bnb.optim.AdamW8bit if optimizer_name
                               == "adamw" else bnb.optim.Adam8bit)
        else:
            optimizer_class = (torch.optim.AdamW
                               if optimizer_name == "adamw" else
                               torch.optim.Adam)
        # NOTE(review): no ``lr`` is passed here, so the optimizer's default
        # applies unless params_to_optimize carries per-group "lr" entries;
        # presumably callers supply those — TODO confirm.
        return optimizer_class(
            params_to_optimize,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            weight_decay=args.adam_weight_decay,
        )

    # Only "prodigy" remains at this point.
    try:
        import prodigyopt
    except ImportError as err:
        raise ImportError(
            "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`"
        ) from err
    if args.learning_rate <= 0.1:
        logger.warning(
            "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
        )
    return prodigyopt.Prodigy(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        beta3=args.prodigy_beta3,
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
        decouple=args.prodigy_decouple,
        use_bias_correction=args.prodigy_use_bias_correction,
        safeguard_warmup=args.prodigy_safeguard_warmup,
    )