import importlib
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Union

import torch
from torch import nn
from torch.optim import Optimizer
from transformers import Trainer, TrainingArguments, get_scheduler

from swift.utils import get_logger

try:
    from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
except ImportError:
    from torch.optim.lr_scheduler import LRScheduler

logger = get_logger()


@dataclass
class GaLoreConfig:
    """
    The configuration class for the GaLore module.

    See https://arxiv.org/abs/2403.03507

    Args:
        rank (`int`): The GaLore rank.
        target_modules (`Union[str, List[str]]`): The target modules to apply GaLore to; if `None`,
            all attention and MLP linear layers will be used.
        update_proj_gap (`int`): The projection update interval for GaLore.
        proj_type (`str`): The projection type of GaLore; valid values are `std`,
            `reverse_std`, `right`, `left`, `full`.
        galore_scale (`float`): The scaling factor applied to the GaLore gradient.
        optim_per_parameter (`bool`): Use a separate optimizer for each parameter.
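
    Example (an illustrative sketch; the module name patterns and values are
    assumptions that depend on the model being trained)::

        config = GaLoreConfig(
            rank=128,
            target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'],
            update_proj_gap=200,
            galore_scale=0.25,
        )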
    """
    rank: int = 128
    target_modules: Union[str, List[str]] = None
    update_proj_gap: int = 50
    galore_scale: float = 1.0
    proj_type: str = 'std'
    optim_per_parameter: bool = False
    quantize: bool = False
    proj_quant: bool = False
    proj_bits: int = 4
    proj_group_size: int = 256
    cos_threshold: float = 0.4
    gamma_proj: int = 2
    queue_size: int = 5


class GaloreOptimizerWrapper(Optimizer):
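    """Fans ``zero_grad`` and ``step`` out to a dict of per-parameter optimizers
    (used when ``optim_per_parameter=True``).

    The base ``Optimizer`` constructor is fed a dummy tensor parameter group,
    presumably so the wrapper can stand in wherever a regular optimizer is
    expected; the wrapped optimizers perform the actual updates.
    """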

    def __init__(self, optimizers: Dict[Any, Optimizer]):
        self.optimizers = optimizers
        super().__init__([torch.tensor([1., 2., 3.])], {'lr': 1.})

    def zero_grad(self, *args, **kwargs) -> None:
        for optim in self.optimizers.values():
            optim.zero_grad(*args, **kwargs)

    def step(self, *args, **kwargs) -> None:
        for optim in self.optimizers.values():
            optim.step(*args, **kwargs)


class GaloreSchedulerWrapper(LRScheduler):
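    """Fans ``step`` out to a dict of per-parameter LR schedulers.

    The base ``LRScheduler`` initializer is not invoked; this wrapper only
    forwards ``step`` and reports ``_last_lr`` from the last scheduler stepped.
    """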

    def __init__(self, lr_schedulers: Dict[Any, LRScheduler]):
        self.lr_schedulers = lr_schedulers

    def step(self, *args, **kwargs) -> None:
        for lr_scheduler in self.lr_schedulers.values():
            lr_scheduler.step(*args, **kwargs)
        self._last_lr = lr_scheduler.get_last_lr()


def create_optimizer_and_scheduler(model: nn.Module, args: TrainingArguments, config: GaLoreConfig, max_steps,
                                   **defaults):
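    """Build a GaLore optimizer and LR scheduler for `model`.

    Sketch of what the code below does: weights of matching `nn.Linear` /
    `nn.Embedding` modules are collected into a GaLore parameter group
    (GaLore hyperparameters plus `**defaults`), the remaining trainable
    parameters are split into decay / no-decay groups, and a scheduler is
    created via `transformers.get_scheduler`. When `config.optim_per_parameter`
    is set (and quantization is off), one optimizer and one scheduler are
    created per parameter and returned wrapped in `GaloreOptimizerWrapper` /
    `GaloreSchedulerWrapper`.
    """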
    # Collect the weights of the target Linear/Embedding modules for GaLore.
    galore_params = []
    for module_name, module in model.named_modules():
        if not isinstance(module, (nn.Linear, nn.Embedding)) or \
                not any(target_key in module_name for target_key in config.target_modules):
            continue

        if not module.weight.requires_grad:
            continue

        logger.info(f'Enable GaLore for weights in module: {module_name}')
        galore_params.append(module.weight)

    id_galore_params = [id(p) for p in galore_params]
    galore_defaults = {
        'rank': config.rank,
        'update_proj_gap': config.update_proj_gap,
        'scale': config.galore_scale,
        'proj_type': config.proj_type,
        **defaults
    }
    if config.quantize:
        # Extra hyperparameters used by the quantized (Q-GaLore) projection.
        galore_defaults['quant'] = config.proj_quant
        galore_defaults['quant_n_bit'] = config.proj_bits
        galore_defaults['quant_group_size'] = config.proj_group_size
        galore_defaults['cos_threshold'] = config.cos_threshold
        galore_defaults['gamma_proj'] = config.gamma_proj
        galore_defaults['queue_size'] = config.queue_size
    optim_cls, optim_kwargs = get_optimizer(args, config)

    if config.optim_per_parameter and not config.quantize:
        # One optimizer and one scheduler per trainable parameter.
        optimizer_dict = {}
        galore_defaults['update_proj_gap'] = galore_defaults['update_proj_gap'] * 2
        for p in model.parameters():
            if p.requires_grad:
                if id(p) in id_galore_params:
                    optimizer_dict[p] = optim_cls([{'params': [p], **galore_defaults}], **optim_kwargs)
                else:
                    optimizer_dict[p] = optim_cls([{'params': [p], **defaults}], **optim_kwargs)

        scheduler_dict = {}
        for p in model.parameters():
            if p.requires_grad:
                scheduler_dict[p] = get_scheduler(
                    optimizer=optimizer_dict[p],
                    name=args.lr_scheduler_type,
                    num_training_steps=max_steps * 2,
                    num_warmup_steps=args.warmup_steps * 2,
                    scheduler_specific_kwargs=args.lr_scheduler_kwargs,
                )

        return GaloreOptimizerWrapper(optimizer_dict), GaloreSchedulerWrapper(scheduler_dict)
    else:
        decay_parameters = Trainer.get_decay_parameter_names(Trainer, model)
        # One group for the GaLore weights, plus decay / no-decay groups for the rest.
        param_groups = [{
            'params': galore_params,
            **galore_defaults,
        }]
        param_groups.extend([
            {
                'params': [
                    p for n, p in model.named_parameters()
                    if (n in decay_parameters and id(p) not in id_galore_params and p.requires_grad)
                ],
                'weight_decay': defaults['weight_decay'],
            },
            {
                'params': [
                    p for n, p in model.named_parameters()
                    if (n not in decay_parameters and id(p) not in id_galore_params and p.requires_grad)
                ],
                'weight_decay': 0.0,
            },
        ])
        optim = optim_cls(param_groups, **optim_kwargs)
        scheduler = get_scheduler(
            optimizer=optim,
            name=args.lr_scheduler_type,
            num_training_steps=max_steps,
            num_warmup_steps=args.warmup_steps,
            scheduler_specific_kwargs=args.lr_scheduler_kwargs,
        )
        return optim, scheduler
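
# Illustrative usage (a sketch, not an API defined in this module; names such
# as `training_args`, `galore_config` and `train_dataset` are placeholders):
# the pair returned by `create_optimizer_and_scheduler` can be handed to a
# HuggingFace `Trainer` through its `optimizers` argument, e.g.
#
#     optimizer, lr_scheduler = create_optimizer_and_scheduler(
#         model, training_args, galore_config, max_steps,
#         lr=training_args.learning_rate, weight_decay=training_args.weight_decay)
#     trainer = Trainer(model=model, args=training_args,
#                       train_dataset=train_dataset,
#                       optimizers=(optimizer, lr_scheduler))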


def get_optimizer(args: TrainingArguments, config: GaLoreConfig) -> Tuple[Any, Any]:
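    """Map `args.optim` to a (GaLore optimizer class, optimizer kwargs) pair."""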

    # Parse extra optimizer arguments of the form 'key1=value1,key2=value2'.
    optim_args = {}
    if args.optim_args:
        for mapping in args.optim_args.replace(' ', '').split(','):
            key, value = mapping.split('=')
            optim_args[key] = value

    optimizer_kwargs = {'lr': args.learning_rate}

    adam_kwargs = {
        'betas': (args.adam_beta1, args.adam_beta2),
        'eps': args.adam_epsilon,
    }
    if args.optim == 'adafactor':
        from .adafactor import GaLoreAdafactor
        optimizer_cls = GaLoreAdafactor
        optimizer_kwargs.update({'scale_parameter': False, 'relative_step': False})
    elif args.optim in ('adamw_hf', 'adamw_torch'):
        if config.quantize:
            assert importlib.util.find_spec('q_galore_torch') is not None, \
                'Please install q-galore by `pip install q_galore_torch`'
            logger.info('If you encounter an `absmax2` error, please downgrade your bitsandbytes to 0.40.0')
            from swift.utils import get_dist_setting
            _, _, world_size, _ = get_dist_setting()
            # Both branches currently resolve to the same 8-bit Q-GaLore optimizer class.
            if world_size > 1:
                from q_galore_torch import QGaLoreAdamW8bit as GaLoreAdamW
            else:
                from q_galore_torch import QGaLoreAdamW8bit as GaLoreAdamW
        else:
            from .adamw import GaLoreAdamW
        optimizer_cls = GaLoreAdamW
        optimizer_kwargs.update(adam_kwargs)
    elif 'adamw' in args.optim and '8bit' in args.optim:
        try:
            from .adamw8bit import GaLoreAdamW8bit
            optimizer_cls = GaLoreAdamW8bit
            optimizer_kwargs.update(adam_kwargs)
            optimizer_kwargs.update({'optim_bits': 8, 'is_paged': 'paged' in args.optim})
        except ImportError:
            raise ValueError('Trainer tried to instantiate bnb optimizer but bnb is not installed!')
    else:
        raise ValueError(f'Galore not supported for optimizer type: {args.optim}')
    return optimizer_cls, optimizer_kwargs