# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
import platform
from dataclasses import dataclass, field
from typing import List, Literal, Optional, Union
from transformers.training_args import TrainingArguments as HfTrainingArguments
from transformers.training_args_seq2seq import Seq2SeqTrainingArguments as HfSeq2SeqTrainingArguments
from swift.utils import get_dist_setting, get_logger, is_liger_available, is_mp, json_parse_to_dict
from .optimizers.galore import GaLoreConfig
logger = get_logger()
@dataclass
class TrainArgumentsMixin:
"""
check_model (bool): Flag to check the model is latest. Default is True.
acc_strategy (Literal['token', 'seq']): Strategy for accumulation. Default is 'token'.
"""
per_device_train_batch_size: int = 1
per_device_eval_batch_size: int = 1
gradient_accumulation_steps: Optional[int] = None
tuner_backend: Optional[str] = None
gradient_checkpointing: bool = True
vit_gradient_checkpointing: Optional[bool] = None
gradient_checkpointing_kwargs: Optional[Union[dict, str]] = None
logging_first_step: bool = True
logging_steps: int = 5
weight_decay: float = 0.1
adam_beta2: float = 0.95
lr_scheduler_type: str = 'cosine'
lr_scheduler_kwargs: Optional[Union[dict, str]] = None
report_to: List[str] = field(default_factory=lambda: ['tensorboard'])
dataloader_num_workers: Optional[int] = None
dataloader_persistent_workers: bool = False
dataloader_prefetch_factor: Optional[int] = None
use_liger_kernel: bool = False
# extra
check_model: bool = True
acc_strategy: Literal['token', 'seq'] = 'token'
train_dataloader_shuffle: bool = True
max_epochs: Optional[int] = None
aligner_lr: Optional[float] = None
vit_lr: Optional[float] = None
optimizer: Optional[str] = None
use_logits_to_keep: Optional[bool] = None
    channels: Optional[List[str]] = None
ds3_gather_for_generation: bool = True
resume_only_model: bool = False
# train-eval loop args
eval_use_evalscope: bool = False
eval_dataset: List[str] = field(default_factory=list)
eval_dataset_args: Optional[Union[str, dict]] = None
eval_limit: Optional[int] = None
eval_generation_config: Optional[Union[str, dict]] = None
    @staticmethod
    def _patch_liger_kernel():
        # Patch LigerForCausalLMLoss so that hidden_states is made contiguous before
        # the fused loss kernel runs; slicing via logits_to_keep can otherwise hand
        # the kernel a non-contiguous tensor.
        from liger_kernel.transformers.model import loss_utils
        origin_LigerForCausalLMLoss = loss_utils.LigerForCausalLMLoss

        def LigerForCausalLMLoss(hidden_states, *args, **kwargs):
            hidden_states = hidden_states.contiguous()
            return origin_LigerForCausalLMLoss(hidden_states, *args, **kwargs)

        loss_utils.LigerForCausalLMLoss = LigerForCausalLMLoss
        logger.info('Patched liger_kernel successfully.')
    def _init_liger(self):
        if self.use_liger_kernel:
            assert is_liger_available(), ('use_liger_kernel requires liger-kernel; '
                                          'install it with `pip install liger-kernel`.')
            try:
                self._patch_liger_kernel()
            except Exception:
                logger.warning('Failed to patch liger_kernel; continuing without the patch.')
def __post_init__(self):
if is_mp() and self.use_liger_kernel:
raise ValueError('liger_kernel does not support device_map. '
'Please use DDP/DeepSpeed for multi-GPU training.')
if self.optimizer is None and (self.vit_lr is not None or self.aligner_lr is not None):
self.optimizer = 'multimodal'
        if self.gradient_accumulation_steps is None:
            world_size = get_dist_setting()[2]
            # Default to a global batch size of 16:
            # per_device_train_batch_size * world_size * gradient_accumulation_steps >= 16.
            self.gradient_accumulation_steps = max(1, math.ceil(16 / self.per_device_train_batch_size / world_size))
            logger.info(f'Setting args.gradient_accumulation_steps: {self.gradient_accumulation_steps}')
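            # e.g. per_device_train_batch_size=2 on 4 processes gives
            # gradient_accumulation_steps = ceil(16 / 2 / 4) = 2, for a global
            # batch size of 2 * 4 * 2 = 16.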
if self.lr_scheduler_kwargs:
self.lr_scheduler_kwargs = json_parse_to_dict(self.lr_scheduler_kwargs)
if self.vit_gradient_checkpointing is None:
self.vit_gradient_checkpointing = self.gradient_checkpointing
if self.gradient_checkpointing_kwargs:
self.gradient_checkpointing_kwargs = json_parse_to_dict(self.gradient_checkpointing_kwargs)
self._init_liger()
if self.dataloader_num_workers is None:
if platform.system() == 'Windows':
self.dataloader_num_workers = 0
else:
self.dataloader_num_workers = 1
logger.info(f'Setting args.dataloader_num_workers: {self.dataloader_num_workers}')
if self.dataloader_prefetch_factor is None and self.dataloader_num_workers > 0:
self.dataloader_prefetch_factor = 10
if self.eval_use_evalscope:
try:
import evalscope
except ImportError:
                raise ImportError('evalscope is not installed. Please install it with `pip install evalscope`.')
self.eval_dataset_args = json_parse_to_dict(self.eval_dataset_args)
self.eval_generation_config = json_parse_to_dict(self.eval_generation_config)
super().__post_init__()
@dataclass
class RLHFArgumentsMixin:
    # GKD: weight of the auxiliary SFT loss mixed into the GKD objective.
    sft_alpha: float = 0
@dataclass
class SwiftArgumentsMixin(RLHFArgumentsMixin, TrainArgumentsMixin):
    # Values copied from TrainArguments
train_type: Optional[str] = None
local_repo_path: Optional[str] = None
galore_config: Optional[GaLoreConfig] = None
def __post_init__(self):
if hasattr(self, 'output_dir'):
self.output_dir = os.path.abspath(os.path.expanduser(self.output_dir))
super().__post_init__()
@dataclass
class VllmArguments:
"""
VllmArguments is a dataclass that holds the configuration for vllm.
Args:
vllm_gpu_memory_utilization (float): GPU memory utilization. Default is 0.9.
vllm_tensor_parallel_size (int): Tensor parallelism size. Default is 1.
vllm_pipeline_parallel_size(int): Pipeline parallelism size. Default is 1.
vllm_max_num_seqs (int): Maximum number of sequences. Default is 256.
vllm_max_model_len (Optional[int]): Maximum model length. Default is None.
vllm_disable_custom_all_reduce (bool): Flag to disable custom all-reduce. Default is True.
vllm_enforce_eager (bool): Flag to enforce eager execution. Default is False.
vllm_limit_mm_per_prompt (Optional[str]): Limit multimedia per prompt. Default is None.
vllm_max_lora_rank (int): Maximum LoRA rank. Default is 16.
vllm_enable_prefix_caching (bool): Flag to enable automatic prefix caching. Default is False.
vllm_use_async_engine (bool): Whether to use async engine for vLLM. Default is False.
vllm_quantization (Optional[str]): The quantization method for vLLM. Default is None.
vllm_data_parallel_size (int): Data parallelism size for vLLM rollout. Default is 1.
"""
# vllm
vllm_gpu_memory_utilization: float = 0.9
vllm_tensor_parallel_size: int = 1
vllm_pipeline_parallel_size: int = 1
vllm_enable_expert_parallel: bool = False
vllm_max_num_seqs: int = 256
vllm_max_model_len: Optional[int] = None
vllm_disable_custom_all_reduce: bool = True
vllm_enforce_eager: bool = False
vllm_limit_mm_per_prompt: Optional[Union[dict, str]] = None # '{"image": 5, "video": 2}'
vllm_max_lora_rank: int = 16
vllm_enable_prefix_caching: bool = False
vllm_use_async_engine: bool = False
vllm_quantization: Optional[str] = None
# rollout
vllm_data_parallel_size: int = 1
    # Deprecated compatibility aliases (to be removed in ms-swift 3.8).
gpu_memory_utilization: Optional[float] = None
tensor_parallel_size: Optional[int] = None
max_model_len: Optional[int] = None
limit_mm_per_prompt: Optional[Union[dict, str]] = None
data_parallel_size: Optional[int] = None
use_async_engine: Optional[bool] = None
def _handle_compatibility(self):
if self.gpu_memory_utilization is not None:
self.vllm_gpu_memory_utilization = self.gpu_memory_utilization
if self.tensor_parallel_size is not None:
self.vllm_tensor_parallel_size = self.tensor_parallel_size
if self.max_model_len is not None:
self.vllm_max_model_len = self.max_model_len
if self.limit_mm_per_prompt is not None:
self.vllm_limit_mm_per_prompt = self.limit_mm_per_prompt
if self.data_parallel_size is not None:
self.vllm_data_parallel_size = self.data_parallel_size
if self.use_async_engine is not None:
self.vllm_use_async_engine = self.use_async_engine
def __post_init__(self):
self._handle_compatibility()
self.vllm_limit_mm_per_prompt = json_parse_to_dict(self.vllm_limit_mm_per_prompt)
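
    # A minimal sketch of the compatibility handling above (hypothetical values):
    #
    #     args = VllmArguments(gpu_memory_utilization=0.8, max_model_len=8192)
    #     args.vllm_gpu_memory_utilization  # -> 0.8 (deprecated alias mapped onto the vllm_* field)
    #     args.vllm_max_model_len           # -> 8192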
    def get_vllm_engine_kwargs(self):
        # `adapters`, `task_type` (and optionally `adapter_mapping`) are expected
        # from the argument class this mixin is combined with; they are not defined here.
        adapters = self.adapters
        if hasattr(self, 'adapter_mapping'):
            adapters = adapters + list(self.adapter_mapping.values())
kwargs = {
'gpu_memory_utilization': self.vllm_gpu_memory_utilization,
'tensor_parallel_size': self.vllm_tensor_parallel_size,
'pipeline_parallel_size': self.vllm_pipeline_parallel_size,
'enable_expert_parallel': self.vllm_enable_expert_parallel,
'max_num_seqs': self.vllm_max_num_seqs,
'max_model_len': self.vllm_max_model_len,
'disable_custom_all_reduce': self.vllm_disable_custom_all_reduce,
'enforce_eager': self.vllm_enforce_eager,
'limit_mm_per_prompt': self.vllm_limit_mm_per_prompt,
'max_lora_rank': self.vllm_max_lora_rank,
'enable_lora': len(adapters) > 0,
'max_loras': max(len(adapters), 1),
'enable_prefix_caching': self.vllm_enable_prefix_caching,
'use_async_engine': self.vllm_use_async_engine,
'quantization': self.vllm_quantization,
}
if self.task_type == 'embedding':
kwargs['task_type'] = 'embed'
return kwargs
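
    # A sketch of the resulting kwargs (hypothetical values, default fields):
    #
    #     kwargs = args.get_vllm_engine_kwargs()
    #     # {'gpu_memory_utilization': 0.9, 'tensor_parallel_size': 1, ...,
    #     #  'enable_lora': False, 'max_loras': 1, 'quantization': None}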
@dataclass
class GRPOArgumentsMixin(VllmArguments):
epsilon: float = 0.2
epsilon_high: Optional[float] = None
delta: Optional[float] = None
top_k: int = 50
top_p: float = 0.9
repetition_penalty: float = 1.
# vllm
vllm_mode: Literal['server', 'colocate'] = 'colocate'
# internal vllm (colocate)
vllm_enable_prefix_caching: bool = True # overwrite
# external vllm (server)
vllm_server_base_url: Optional[List[str]] = None
vllm_server_host: Optional[List[str]] = None
vllm_server_port: List[int] = field(default_factory=lambda: [8000])
vllm_server_timeout: float = 240.0
    vllm_client = None  # Not meant to be set by the user; populated when the client is instantiated.
# reward function args, see details in swift/plugin/orm.py
# cosine reward, https://arxiv.org/abs/2502.03373
    cosine_min_len_value_wrong: float = -0.5  # r^w_0 in the paper: reward for wrong answers with zero completion length.
    cosine_max_len_value_wrong: float = 0.0  # r^w_L in the paper: reward for wrong answers at max completion length.
    cosine_min_len_value_correct: float = 1.0  # r^c_0 in the paper: reward for correct answers with zero completion length.
    cosine_max_len_value_correct: float = 0.5  # r^c_L in the paper: reward for correct answers at max completion length.
    cosine_max_len: Optional[int] = None  # L_max in the paper; defaults to max_completion_length.
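    # A sketch of the cosine length-scaled reward these values parameterize (the
    # actual implementation lives in swift/plugin/orm.py; this is an illustrative
    # cosine interpolation from r_0 at length 0 to r_L at max_len):
    #
    #     import math
    #
    #     def cosine_reward(length: int, max_len: int, r_0: float, r_L: float) -> float:
    #         return r_L + 0.5 * (r_0 - r_L) * (1 + math.cos(math.pi * length / max_len))
    #
    #     cosine_reward(0, 1024, 1.0, 0.5)     # -> 1.0 (correct answer, zero length)
    #     cosine_reward(1024, 1024, 1.0, 0.5)  # -> 0.5 (correct answer, max length)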
# repetition penalty, https://arxiv.org/abs/2502.03373
repetition_n_grams: int = 3
repetition_max_penalty: float = -1.0
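    # A sketch of the n-gram repetition penalty (illustrative only; see
    # swift/plugin/orm.py for the actual implementation): the penalty scales
    # with the fraction of duplicated n-grams in the completion.
    #
    #     def repetition_penalty_reward(tokens: list, n: int = 3, max_penalty: float = -1.0) -> float:
    #         ngrams = [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
    #         if not ngrams:
    #             return 0.0
    #         duplicated_fraction = 1 - len(set(ngrams)) / len(ngrams)
    #         return duplicated_fraction * max_penalty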
reward_model: Optional[List[str]] = None
reward_model_plugin: Optional[List[str]] = None
# sync ref model
sync_ref_model: bool = False
ref_model_sync_steps: int = 512
ref_model_mixup_alpha: float = 0.6
async_generate: bool = False
sleep_level: int = 0
move_model_batches: Optional[int] = None
offload_optimizer: bool = False
offload_model: bool = False
gc_collect_after_offload: bool = False # deprecated
# multi turn
multi_turn_func: Optional[str] = None # deprecated
multi_turn_scheduler: Optional[str] = None
max_turns: Optional[int] = None
completion_length_limit_scope: Literal['total', 'per_round'] = 'per_round'
# DAPO, https://arxiv.org/abs/2503.14476
dynamic_sample: bool = False
max_resample_times: int = 3
overlong_filter: bool = False
soft_max_length: Optional[int] = None
soft_cache_length: Optional[int] = None
# Dr. GRPO, https://arxiv.org/abs/2503.20783
scale_rewards: bool = True
# entropy
log_entropy: bool = False
# Beyond the 80/20 Rule, https://arxiv.org/abs/2506.01939
top_entropy_quantile: float = 1.0
    # GSPO, https://arxiv.org/abs/2507.18071
importance_sampling_level: Literal['token', 'sequence'] = 'token'
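    # A sketch of the distinction (assuming per-token log-probs `logp_new`/`logp_old`
    # of the sampled completion as torch tensors): 'token' keeps per-token importance
    # ratios, while 'sequence' (GSPO) uses the geometric mean over the sequence.
    #
    #     token_ratio = (logp_new - logp_old).exp()        # shape [seq_len]
    #     seq_ratio = (logp_new - logp_old).mean().exp()   # scalar, shared by all tokens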
wandb_log_unique_prompts: Optional[bool] = None
generation_batch_size: Optional[int] = None
steps_per_generation: Optional[int] = None
# dataset
dataset_shuffle: Optional[bool] = True
@dataclass
class TrainingArguments(SwiftArgumentsMixin, HfTrainingArguments):
pass
@dataclass
class Seq2SeqTrainingArguments(SwiftArgumentsMixin, HfSeq2SeqTrainingArguments):
pass
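
# A minimal usage sketch (hypothetical values; TrainingArguments accepts every
# HfTrainingArguments field plus the mixin fields defined above):
#
#     args = TrainingArguments(
#         output_dir='output',
#         per_device_train_batch_size=2,
#         vit_lr=1e-5,  # setting vit_lr/aligner_lr defaults optimizer to 'multimodal'
#     )
#     assert args.optimizer == 'multimodal'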