# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
import platform
from dataclasses import dataclass, field
from typing import List, Literal, Optional, Union

from transformers.training_args import TrainingArguments as HfTrainingArguments
from transformers.training_args_seq2seq import Seq2SeqTrainingArguments as HfSeq2SeqTrainingArguments

from swift.utils import get_dist_setting, get_logger, is_liger_available, is_mp, json_parse_to_dict

from .optimizers.galore import GaLoreConfig

logger = get_logger()


@dataclass
class TrainArgumentsMixin:
"""
check_model (bool): Flag to check the model is latest. Default is True.
acc_strategy (Literal['token', 'seq']): Strategy for accumulation. Default is 'token'.
"""
per_device_train_batch_size: int = 1
per_device_eval_batch_size: int = 1
gradient_accumulation_steps: Optional[int] = None
tuner_backend: Optional[str] = None
gradient_checkpointing: bool = True
vit_gradient_checkpointing: Optional[bool] = None
gradient_checkpointing_kwargs: Optional[Union[dict, str]] = None
logging_first_step: bool = True
logging_steps: int = 5
weight_decay: float = 0.1
adam_beta2: float = 0.95
lr_scheduler_type: str = 'cosine'
lr_scheduler_kwargs: Optional[Union[dict, str]] = None
report_to: List[str] = field(default_factory=lambda: ['tensorboard'])
dataloader_num_workers: Optional[int] = None
dataloader_persistent_workers: bool = False
dataloader_prefetch_factor: Optional[int] = None
use_liger_kernel: bool = False
# extra
check_model: bool = True
acc_strategy: Literal['token', 'seq'] = 'token'
train_dataloader_shuffle: bool = True
max_epochs: Optional[int] = None
aligner_lr: Optional[float] = None
vit_lr: Optional[float] = None
optimizer: Optional[str] = None
use_logits_to_keep: Optional[bool] = None
    channels: Optional[List[str]] = None
ds3_gather_for_generation: bool = True
resume_only_model: bool = False
# train-eval loop args
eval_use_evalscope: bool = False
eval_dataset: List[str] = field(default_factory=list)
eval_dataset_args: Optional[Union[str, dict]] = None
eval_limit: Optional[int] = None
eval_generation_config: Optional[Union[str, dict]] = None
@staticmethod
def _patch_liger_kernel():
# fix logits_to_keep
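        # Slicing hidden states via `logits_to_keep` can produce a non-contiguous view, which the
        # fused liger CausalLM loss does not handle; the wrapper below makes the tensor contiguous
        # before delegating to the original loss function.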
from liger_kernel.transformers.model import loss_utils
origin_LigerForCausalLMLoss = loss_utils.LigerForCausalLMLoss
def LigerForCausalLMLoss(hidden_states, *args, **kwargs):
hidden_states = hidden_states.contiguous()
return origin_LigerForCausalLMLoss(hidden_states, *args, **kwargs)
loss_utils.LigerForCausalLMLoss = LigerForCausalLMLoss
        logger.info('Patched liger_kernel successfully.')
def _init_liger(self):
if self.use_liger_kernel:
            assert is_liger_available(), 'use_liger_kernel requires liger-kernel, try `pip install liger-kernel`'
            try:
                self._patch_liger_kernel()
            except Exception as e:
                logger.warning(f'Failed to patch liger_kernel: {e}')
def __post_init__(self):
if is_mp() and self.use_liger_kernel:
raise ValueError('liger_kernel does not support device_map. '
'Please use DDP/DeepSpeed for multi-GPU training.')
if self.optimizer is None and (self.vit_lr is not None or self.aligner_lr is not None):
self.optimizer = 'multimodal'
if self.gradient_accumulation_steps is None:
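            # Default heuristic: choose the smallest accumulation that yields a global batch size of
            # roughly 16, i.e. ceil(16 / per_device_train_batch_size / world_size).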
world_size = get_dist_setting()[2]
self.gradient_accumulation_steps = max(1, math.ceil(16 / self.per_device_train_batch_size / world_size))
logger.info(f'Setting args.gradient_accumulation_steps: {self.gradient_accumulation_steps}')
if self.lr_scheduler_kwargs:
self.lr_scheduler_kwargs = json_parse_to_dict(self.lr_scheduler_kwargs)
if self.vit_gradient_checkpointing is None:
self.vit_gradient_checkpointing = self.gradient_checkpointing
if self.gradient_checkpointing_kwargs:
self.gradient_checkpointing_kwargs = json_parse_to_dict(self.gradient_checkpointing_kwargs)
self._init_liger()
if self.dataloader_num_workers is None:
if platform.system() == 'Windows':
self.dataloader_num_workers = 0
else:
self.dataloader_num_workers = 1
logger.info(f'Setting args.dataloader_num_workers: {self.dataloader_num_workers}')
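        # torch's DataLoader only accepts a prefetch_factor when num_workers > 0.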
if self.dataloader_prefetch_factor is None and self.dataloader_num_workers > 0:
self.dataloader_prefetch_factor = 10
if self.eval_use_evalscope:
try:
import evalscope
except ImportError:
raise ImportError('evalscope is not installed, please install it by `pip install evalscope`')
self.eval_dataset_args = json_parse_to_dict(self.eval_dataset_args)
self.eval_generation_config = json_parse_to_dict(self.eval_generation_config)
super().__post_init__()


@dataclass
class RLHFArgumentsMixin:
# gkd
sft_alpha: float = 0


@dataclass
class SwiftArgumentsMixin(RLHFArgumentsMixin, TrainArgumentsMixin):
    # Values copied from TrainArguments
train_type: Optional[str] = None
local_repo_path: Optional[str] = None
galore_config: Optional[GaLoreConfig] = None
def __post_init__(self):
if hasattr(self, 'output_dir'):
self.output_dir = os.path.abspath(os.path.expanduser(self.output_dir))
super().__post_init__()


@dataclass
class VllmArguments:
"""
    VllmArguments is a dataclass that holds the configuration for vLLM.
Args:
vllm_gpu_memory_utilization (float): GPU memory utilization. Default is 0.9.
vllm_tensor_parallel_size (int): Tensor parallelism size. Default is 1.
        vllm_pipeline_parallel_size (int): Pipeline parallelism size. Default is 1.
        vllm_enable_expert_parallel (bool): Flag to enable expert parallelism. Default is False.
vllm_max_num_seqs (int): Maximum number of sequences. Default is 256.
vllm_max_model_len (Optional[int]): Maximum model length. Default is None.
vllm_disable_custom_all_reduce (bool): Flag to disable custom all-reduce. Default is True.
vllm_enforce_eager (bool): Flag to enforce eager execution. Default is False.
        vllm_limit_mm_per_prompt (Optional[Union[dict, str]]): Limit on multi-modal items per prompt,
            e.g. '{"image": 5, "video": 2}'. Default is None.
vllm_max_lora_rank (int): Maximum LoRA rank. Default is 16.
vllm_enable_prefix_caching (bool): Flag to enable automatic prefix caching. Default is False.
vllm_use_async_engine (bool): Whether to use async engine for vLLM. Default is False.
vllm_quantization (Optional[str]): The quantization method for vLLM. Default is None.
vllm_data_parallel_size (int): Data parallelism size for vLLM rollout. Default is 1.
"""
# vllm
vllm_gpu_memory_utilization: float = 0.9
vllm_tensor_parallel_size: int = 1
vllm_pipeline_parallel_size: int = 1
vllm_enable_expert_parallel: bool = False
vllm_max_num_seqs: int = 256
vllm_max_model_len: Optional[int] = None
vllm_disable_custom_all_reduce: bool = True
vllm_enforce_eager: bool = False
vllm_limit_mm_per_prompt: Optional[Union[dict, str]] = None # '{"image": 5, "video": 2}'
vllm_max_lora_rank: int = 16
vllm_enable_prefix_caching: bool = False
vllm_use_async_engine: bool = False
vllm_quantization: Optional[str] = None
# rollout
vllm_data_parallel_size: int = 1
# compatibility (will be removed in ms-swift 3.8 and later)
gpu_memory_utilization: Optional[float] = None
tensor_parallel_size: Optional[int] = None
max_model_len: Optional[int] = None
limit_mm_per_prompt: Optional[Union[dict, str]] = None
data_parallel_size: Optional[int] = None
use_async_engine: Optional[bool] = None
def _handle_compatibility(self):
if self.gpu_memory_utilization is not None:
self.vllm_gpu_memory_utilization = self.gpu_memory_utilization
if self.tensor_parallel_size is not None:
self.vllm_tensor_parallel_size = self.tensor_parallel_size
if self.max_model_len is not None:
self.vllm_max_model_len = self.max_model_len
if self.limit_mm_per_prompt is not None:
self.vllm_limit_mm_per_prompt = self.limit_mm_per_prompt
if self.data_parallel_size is not None:
self.vllm_data_parallel_size = self.data_parallel_size
if self.use_async_engine is not None:
self.vllm_use_async_engine = self.use_async_engine
def __post_init__(self):
self._handle_compatibility()
self.vllm_limit_mm_per_prompt = json_parse_to_dict(self.vllm_limit_mm_per_prompt)
def get_vllm_engine_kwargs(self):
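        # `adapters`, `adapter_mapping` and `task_type` are defined by the argument classes that mix
        # VllmArguments in, not by this dataclass itself.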
adapters = self.adapters
if hasattr(self, 'adapter_mapping'):
adapters = adapters + list(self.adapter_mapping.values())
kwargs = {
'gpu_memory_utilization': self.vllm_gpu_memory_utilization,
'tensor_parallel_size': self.vllm_tensor_parallel_size,
'pipeline_parallel_size': self.vllm_pipeline_parallel_size,
'enable_expert_parallel': self.vllm_enable_expert_parallel,
'max_num_seqs': self.vllm_max_num_seqs,
'max_model_len': self.vllm_max_model_len,
'disable_custom_all_reduce': self.vllm_disable_custom_all_reduce,
'enforce_eager': self.vllm_enforce_eager,
'limit_mm_per_prompt': self.vllm_limit_mm_per_prompt,
'max_lora_rank': self.vllm_max_lora_rank,
'enable_lora': len(adapters) > 0,
'max_loras': max(len(adapters), 1),
'enable_prefix_caching': self.vllm_enable_prefix_caching,
'use_async_engine': self.vllm_use_async_engine,
'quantization': self.vllm_quantization,
}
if self.task_type == 'embedding':
kwargs['task_type'] = 'embed'
return kwargs


@dataclass
class GRPOArgumentsMixin(VllmArguments):
epsilon: float = 0.2
epsilon_high: Optional[float] = None
delta: Optional[float] = None
top_k: int = 50
top_p: float = 0.9
repetition_penalty: float = 1.
# vllm
vllm_mode: Literal['server', 'colocate'] = 'colocate'
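    # 'colocate' runs the vLLM engine inside the training processes; 'server' connects to an
    # external rollout server configured via the vllm_server_* options below.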
# internal vllm (colocate)
vllm_enable_prefix_caching: bool = True # overwrite
# external vllm (server)
vllm_server_base_url: Optional[List[str]] = None
vllm_server_host: Optional[List[str]] = None
vllm_server_port: List[int] = field(default_factory=lambda: [8000])
vllm_server_timeout: float = 240.0
vllm_client = None # Not required to set, used for client instantiation
# reward function args, see details in swift/plugin/orm.py
# cosine reward, https://arxiv.org/abs/2502.03373
cosine_min_len_value_wrong: float = -0.5 # r^w_0 in paper, Reward for wrong answers with zero completion length.
cosine_max_len_value_wrong: float = 0.0 # r^w_L in paper, Reward for wrong answers with max completion length.
cosine_min_len_value_correct: float = 1.0 # r^c_0 in paper, Reward for correct answers with zero completion length.
cosine_max_len_value_correct: float = 0.5 # r^c_L in paper, Reward for correct answers with max completion length.
cosine_max_len: Optional[int] = None # Lmax in paper, default equal to max_completion_length
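    # Sketch of the cosine length schedule described in the paper (the reward plugin in
    # swift/plugin/orm.py is authoritative): r(L) = r_L + 0.5 * (r_0 - r_L) * (1 + cos(pi * L / Lmax)),
    # where (r_0, r_L) are the zero-length / max-length values for correct or wrong answers above.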
# repetition penalty, https://arxiv.org/abs/2502.03373
repetition_n_grams: int = 3
repetition_max_penalty: float = -1.0
reward_model: Optional[List[str]] = None
reward_model_plugin: Optional[List[str]] = None
# sync ref model
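    # These mirror the reference-model synchronization options of the same name in TRL.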
sync_ref_model: bool = False
ref_model_sync_steps: int = 512
ref_model_mixup_alpha: float = 0.6
async_generate: bool = False
sleep_level: int = 0
move_model_batches: Optional[int] = None
offload_optimizer: bool = False
offload_model: bool = False
gc_collect_after_offload: bool = False # deprecated
# multi turn
multi_turn_func: Optional[str] = None # deprecated
multi_turn_scheduler: Optional[str] = None
max_turns: Optional[int] = None
completion_length_limit_scope: Literal['total', 'per_round'] = 'per_round'
# DAPO, https://arxiv.org/abs/2503.14476
dynamic_sample: bool = False
max_resample_times: int = 3
overlong_filter: bool = False
soft_max_length: Optional[int] = None
soft_cache_length: Optional[int] = None
# Dr. GRPO, https://arxiv.org/abs/2503.20783
scale_rewards: bool = True
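    # Dr. GRPO argues against dividing advantages by the per-group reward std; set False to follow it.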
# entropy
log_entropy: bool = False
# Beyond the 80/20 Rule, https://arxiv.org/abs/2506.01939
top_entropy_quantile: float = 1.0
# GSPO https://www.arxiv.org/abs/2507.18071
importance_sampling_level: Literal['token', 'sequence'] = 'token'
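    # 'sequence' aggregates the per-token log-ratios over the completion into a sequence-level
    # importance weight (GSPO); 'token' keeps the standard per-token ratio used by GRPO.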
wandb_log_unique_prompts: Optional[bool] = None
generation_batch_size: Optional[int] = None
steps_per_generation: Optional[int] = None
# dataset
dataset_shuffle: Optional[bool] = True


@dataclass
class TrainingArguments(SwiftArgumentsMixin, HfTrainingArguments):
pass


@dataclass
class Seq2SeqTrainingArguments(SwiftArgumentsMixin, HfSeq2SeqTrainingArguments):
pass
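

# A minimal usage sketch (illustrative only; `output_dir` and the hyper-parameter values below are
# placeholders, not defaults of this module):
#
#     args = Seq2SeqTrainingArguments(
#         output_dir='output',
#         per_device_train_batch_size=1,
#         vit_lr=1e-5,  # with `optimizer` left unset, this selects the 'multimodal' optimizer
#     )
#     print(args.gradient_accumulation_steps)  # derived from the world size when not set explicitly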