# Copyright (c) Alibaba, Inc. and its affiliates.
import importlib
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.optim import Optimizer
from transformers import Trainer, TrainingArguments, get_scheduler

from swift.utils import get_logger

try:
    from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
except ImportError:
    from torch.optim.lr_scheduler import LRScheduler

logger = get_logger()


@dataclass
class GaLoreConfig:
    """
    The configuration class for the GaLore module.

    See https://arxiv.org/abs/2403.03507

    Args:
        rank (`int`): The GaLore rank.
        target_modules (`Optional[Union[str, List[str]]]`): The modules to apply GaLore to. If `None`,
            all attention and MLP linear layers are used.
        update_proj_gap (`int`): The number of steps between projection updates.
        galore_scale (`float`): The scaling factor applied to the GaLore gradient.
        proj_type (`str`): The projection type. Valid values are `std`,
            `reverse_std`, `right`, `left`, `full`.
        optim_per_parameter (`bool`): Whether to use a separate optimizer for each parameter.
        quantize (`bool`): Whether to use Q-GaLore (quantized projections).
        proj_quant (`bool`): Whether to quantize the projection matrices (Q-GaLore only).
        proj_bits (`int`): The number of bits used to quantize the projection matrices.
        proj_group_size (`int`): The group size used to quantize the projection matrices.
        cos_threshold (`float`): The cosine-similarity threshold for Q-GaLore's adaptive projection update.
        gamma_proj (`int`): The factor by which Q-GaLore enlarges the projection update interval
            once the projections stabilize.
        queue_size (`int`): The length of the projection-history queue used by Q-GaLore.
    """
    rank: int = 128
    target_modules: Optional[Union[str, List[str]]] = None
    update_proj_gap: int = 50
    galore_scale: float = 1.0
    proj_type: str = 'std'
    optim_per_parameter: bool = False
    quantize: bool = False
    proj_quant: bool = False
    proj_bits: int = 4
    proj_group_size: int = 256
    cos_threshold: float = 0.4
    gamma_proj: int = 2
    queue_size: int = 5
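
# A minimal usage sketch (not taken from this repository): constructing a GaLoreConfig for the
# attention and MLP projections of a decoder-only transformer. The module-name patterns below are
# illustrative assumptions; they must match substrings of names yielded by `model.named_modules()`.
#
#     config = GaLoreConfig(
#         rank=128,
#         target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],
#         update_proj_gap=200,
#         galore_scale=0.25,
#     )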


class GaloreOptimizerWrapper(Optimizer):

    def __init__(self, optimizers: Dict[Any, Optimizer]):
        self.optimizers = optimizers
        # The base Optimizer requires a non-empty param group, so pass a dummy tensor;
        # all real work is delegated to the wrapped per-parameter optimizers.
        super().__init__([torch.tensor([1., 2., 3.])], {'lr': 1.})

    def zero_grad(self, *args, **kwargs) -> None:
        for optim in self.optimizers.values():
            optim.zero_grad(*args, **kwargs)

    def step(self, *args, **kwargs) -> None:
        for optim in self.optimizers.values():
            optim.step(*args, **kwargs)


class GaloreSchedulerWrapper(LRScheduler):

    def __init__(self, lr_schedulers: Dict[Any, LRScheduler]):
        self.lr_schedulers = lr_schedulers

    def step(self, *args, **kwargs) -> None:
        for lr_scheduler in self.lr_schedulers.values():
            lr_scheduler.step(*args, **kwargs)
        self._last_lr = lr_scheduler.get_last_lr()


def create_optimizer_and_scheduler(model: nn.Module, args: TrainingArguments, config: GaLoreConfig, max_steps,
                                   **defaults):
    # Collect the weights of every target Linear/Embedding module that should use GaLore.
    galore_params = []
    for module_name, module in model.named_modules():
        if not isinstance(module, (nn.Linear, nn.Embedding)) or \
                not any(target_key in module_name for target_key in config.target_modules):
            continue
        if not module.weight.requires_grad:
            continue
        logger.info(f'Enable GaLore for weights in module: {module_name}')
        galore_params.append(module.weight)
    id_galore_params = [id(p) for p in galore_params]
    galore_defaults = {
        'rank': config.rank,
        'update_proj_gap': config.update_proj_gap,
        'scale': config.galore_scale,
        'proj_type': config.proj_type,
        **defaults
    }
    if config.quantize:
        galore_defaults['quant'] = config.proj_quant
        galore_defaults['quant_n_bit'] = config.proj_bits
        galore_defaults['quant_group_size'] = config.proj_group_size
        galore_defaults['cos_threshold'] = config.cos_threshold
        galore_defaults['gamma_proj'] = config.gamma_proj
        galore_defaults['queue_size'] = config.queue_size

    optim_cls, optim_kwargs = get_optimizer(args, config)
    if config.optim_per_parameter and not config.quantize:
        # Q-GaLore does not support optim_per_parameter.
        optimizer_dict = {}
        galore_defaults['update_proj_gap'] = galore_defaults['update_proj_gap'] * 2
        for p in model.parameters():
            if p.requires_grad:
                if id(p) in id_galore_params:
                    optimizer_dict[p] = optim_cls([{'params': [p], **galore_defaults}], **optim_kwargs)
                else:
                    optimizer_dict[p] = optim_cls([{'params': [p], **defaults}], **optim_kwargs)

        # Build one scheduler per parameter to match the per-parameter optimizers.
        scheduler_dict = {}
        for p in model.parameters():
            if p.requires_grad:
                scheduler_dict[p] = get_scheduler(
                    optimizer=optimizer_dict[p],
                    name=args.lr_scheduler_type,
                    num_training_steps=max_steps * 2,
                    num_warmup_steps=args.warmup_steps * 2,
                    scheduler_specific_kwargs=args.lr_scheduler_kwargs,
                )

        return GaloreOptimizerWrapper(optimizer_dict), GaloreSchedulerWrapper(scheduler_dict)
    else:
        # get_decay_parameter_names does not use `self`, so it is safe to call it on the class.
        decay_parameters = Trainer.get_decay_parameter_names(Trainer, model)
        param_groups = [{
            'params': galore_params,
            **galore_defaults,
        }]
        param_groups.extend([
            {
                'params': [
                    p for n, p in model.named_parameters()
                    if (n in decay_parameters and id(p) not in id_galore_params and p.requires_grad)
                ],
                'weight_decay': defaults['weight_decay'],
            },
            {
                'params': [
                    p for n, p in model.named_parameters()
                    if (n not in decay_parameters and id(p) not in id_galore_params and p.requires_grad)
                ],
                'weight_decay': 0.0,
            },
        ])
        optim = optim_cls(param_groups, **optim_kwargs)
        scheduler = get_scheduler(
            optimizer=optim,
            name=args.lr_scheduler_type,
            num_training_steps=max_steps,
            num_warmup_steps=args.warmup_steps,
            scheduler_specific_kwargs=args.lr_scheduler_kwargs,
        )
        return optim, scheduler
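
# A rough integration sketch (an assumption, not part of this file's API): a training script could
# hand the returned pair to a transformers Trainer through its `optimizers` argument. `model`,
# `training_args`, `galore_config` and `train_dataset` are hypothetical names used only for illustration.
#
#     optimizer, lr_scheduler = create_optimizer_and_scheduler(
#         model,
#         training_args,
#         galore_config,
#         max_steps=1000,
#         lr=training_args.learning_rate,
#         weight_decay=training_args.weight_decay,
#     )
#     trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset,
#                       optimizers=(optimizer, lr_scheduler))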


def get_optimizer(args: TrainingArguments, config: GaLoreConfig) -> Tuple[Any, Any]:
    # Parse args.optim_args, a comma-separated list of `key=value` pairs.
    optim_args = {}
    if args.optim_args:
        for mapping in args.optim_args.replace(' ', '').split(','):
            key, value = mapping.split('=')
            optim_args[key] = value

    optimizer_kwargs = {'lr': args.learning_rate}
    adam_kwargs = {
        'betas': (args.adam_beta1, args.adam_beta2),
        'eps': args.adam_epsilon,
    }
    if args.optim == 'adafactor':
        from .adafactor import GaLoreAdafactor
        optimizer_cls = GaLoreAdafactor
        optimizer_kwargs.update({'scale_parameter': False, 'relative_step': False})
    elif args.optim in ('adamw_hf', 'adamw_torch'):
        if config.quantize:
            assert importlib.util.find_spec('q_galore_torch') is not None, \
                'Please install q-galore by `pip install q_galore_torch`'
            logger.info('If you encounter an `absmax2` error, please downgrade your bitsandbytes to 0.40.0')
            from swift.utils import get_dist_setting
            _, _, world_size, _ = get_dist_setting()
            if world_size > 1:
                # from q_galore_torch import QGaLoreAdamW8bit_simulate as GaLoreAdamW
                from q_galore_torch import QGaLoreAdamW8bit as GaLoreAdamW
            else:
                from q_galore_torch import QGaLoreAdamW8bit as GaLoreAdamW
        else:
            from .adamw import GaLoreAdamW
        optimizer_cls = GaLoreAdamW
        optimizer_kwargs.update(adam_kwargs)
    elif 'adamw' in args.optim and '8bit' in args.optim:
        try:
            from .adamw8bit import GaLoreAdamW8bit
            optimizer_cls = GaLoreAdamW8bit
            optimizer_kwargs.update(adam_kwargs)
            optimizer_kwargs.update({'optim_bits': 8, 'is_paged': 'paged' in args.optim})
        except ImportError:
            raise ValueError('Trainer tried to instantiate bnb optimizer but bnb is not installed!')
    else:
        raise ValueError(f'GaLore is not supported for optimizer type: {args.optim}')
    return optimizer_cls, optimizer_kwargs
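
# Note on `args.optim_args` (a reading of the parsing code above, not documented behavior): it is
# expected to be a comma-separated list of `key=value` pairs, e.g. optim_args='rank=64,foo=bar'
# parses to {'rank': '64', 'foo': 'bar'}; values are kept as strings.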