# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
import sys
from typing import TYPE_CHECKING, List, Optional, Tuple

import torch.nn as nn
from peft import PeftModel
from transformers import Trainer

from swift.trainers.optimizers.galore import create_optimizer_and_scheduler
from swift.utils import get_dist_setting

if TYPE_CHECKING:
    from swift.trainers import TrainingArguments


def calculate_max_steps(args: 'TrainingArguments', dataset) -> int:
    if args.max_steps and args.max_steps > 0:
        max_steps = args.max_steps
    else:
        len_dataset = len(dataset)
        _, _, world_size, _ = get_dist_setting()
        total_train_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * world_size
        num_update_steps_per_epoch = len_dataset // total_train_batch_size
        num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
        max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
    return max_steps


def create_galore_optimizer(args: 'TrainingArguments', model, dataset):
    training_steps = calculate_max_steps(args, dataset)
    optimizer, lr_scheduler = create_optimizer_and_scheduler(
        model, args, args.galore_config, training_steps, lr=args.learning_rate, weight_decay=args.weight_decay)
    # trainer cannot serialize galore_config
    args.galore_config = None
    return optimizer, lr_scheduler


def create_lorap_optimizer(args: 'TrainingArguments', model, dataset):
    optimizer_grouped_parameters = None
    if hasattr(model, 'create_optimizer_param_groups'):
        # Lora+ parameter groups
        optimizer_grouped_parameters = model.create_optimizer_param_groups(
            lr=args.learning_rate, weight_decay=args.weight_decay)
    if optimizer_grouped_parameters is None:
        # Default parameter groups
        decay_parameters = Trainer.get_decay_parameter_names(None, model)
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in model.named_parameters() if (n in decay_parameters and p.requires_grad)],
                'weight_decay': args.weight_decay,
            },
            {
                'params': [p for n, p in model.named_parameters() if (n not in decay_parameters and p.requires_grad)],
                'weight_decay': 0.0,
            },
        ]
    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
    return optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs), None


def create_muon_optimizer(args: 'TrainingArguments', model, dataset):
    from swift.llm import git_clone_github, get_model_arch
    if not args.local_repo_path:
        args.local_repo_path = git_clone_github('https://github.com/MoonshotAI/Moonlight.git')
    sys.path.append(os.path.join(args.local_repo_path, 'examples'))
    from toy_train import Muon

    # parse args.optim_args
    optim_args = {}
    if args.optim_args:
        for mapping in args.optim_args.replace(' ', '').split(','):
            key, value = mapping.split('=')
            optim_args[key] = value

    model_arch = get_model_arch(model.model_meta.model_arch)
    embed_key = getattr(model_arch, 'embedding', None) or 'embed_tokens'
    lm_head_key = getattr(model_arch, 'lm_head', None) or 'lm_head'
    # Muon updates the >=2D hidden weights; embedding and lm_head parameters fall back to AdamW.
    muon_params = [
        p for n, p in model.named_parameters()
        if p.requires_grad and p.ndim >= 2 and embed_key not in n and lm_head_key not in n
    ]
    adamw_params = [
        p for n, p in model.named_parameters()
        if p.requires_grad and not (p.ndim >= 2 and embed_key not in n and lm_head_key not in n)
    ]

    return Muon(
        lr=args.learning_rate,
        wd=args.weight_decay,
        muon_params=muon_params,
        adamw_params=adamw_params,
        adamw_betas=(args.adam_beta1, args.adam_beta2),
        adamw_eps=args.adam_epsilon,
        **optim_args,
    ), None


def get_param_startswith(model, chosen_prefix: List[str],
                         rejected_prefix: Optional[List[str]] = None) -> List[Tuple[str, nn.Parameter]]:
    """Collect trainable parameters whose names start with a prefix in `chosen_prefix`
    but not with any prefix in `rejected_prefix`."""
    chosen_prefix = chosen_prefix or []
    rejected_prefix = rejected_prefix or []
    res = []
    if not chosen_prefix:
        return res
    is_peft_model = isinstance(model, PeftModel)
    if is_peft_model:
        model = model.model
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        is_rejected = False
        for prefix in rejected_prefix:
            if n.startswith(prefix):
                is_rejected = True
                break
        if is_rejected:
            continue
        for prefix in chosen_prefix:
            if n.startswith(prefix):
                if is_peft_model:
                    # restore the full parameter name of the PeftModel wrapper
                    n = f'base_model.model.{n}'
                res.append((n, p))
                break
    return res


def create_multimodal_optimizer(args: 'TrainingArguments', model, dataset):
    """ViT/Aligner/LLM use different learning rates."""
    from swift.llm import get_model_arch
    decay_parameters = set(Trainer.get_decay_parameter_names(None, model))
    model_arch = get_model_arch(model.model_meta.model_arch)
    vit_parameters = get_param_startswith(model, model_arch.vision_tower, model_arch.aligner)
    aligner_parameters = get_param_startswith(model, model_arch.aligner)
    llm_parameters = get_param_startswith(model, model_arch.language_model)
    optimizer_grouped_parameters = []
    for lr, parameters in zip([args.vit_lr, args.aligner_lr, args.learning_rate],
                              [vit_parameters, aligner_parameters, llm_parameters]):
        if lr is None:
            # fall back to the base learning rate when vit_lr/aligner_lr are not set
            lr = args.learning_rate
        # split each module group into no-decay and decay parameter groups
        for wd in [0., args.weight_decay]:
            if wd == 0:
                params = [p for n, p in parameters if n not in decay_parameters]
            else:
                params = [p for n, p in parameters if n in decay_parameters]
            if not params:
                continue
            optimizer_grouped_parameters.append({
                'params': params,
                'weight_decay': wd,
                'lr': lr,
            })
    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args, model)
    return optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs), None


# Add your own optimizers here, use --optimizer xxx to train
optimizers_map = {
    'galore': create_galore_optimizer,
    'lorap': create_lorap_optimizer,
    'muon': create_muon_optimizer,
    'multimodal': create_multimodal_optimizer,
}
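
# Illustrative sketch only: `create_custom_adamw_optimizer` below is a hypothetical example of the
# factory signature `optimizers_map` expects -- it takes (args, model, dataset) and returns an
# (optimizer, lr_scheduler) tuple, where the scheduler may be None. It is not registered by default;
# to try it, uncomment the registration line and pass `--optimizer custom_adamw`.
def create_custom_adamw_optimizer(args: 'TrainingArguments', model, dataset):
    # Reuse the optimizer class and kwargs selected via `args.optim`, applied to a single group of
    # all trainable parameters (no per-module learning rates, no decay/no-decay split).
    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    return optimizer_cls(trainable_params, **optimizer_kwargs), None


# optimizers_map['custom_adamw'] = create_custom_adamw_optimizer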