from dataclasses import dataclass, field
from typing import List

from trl import CPOConfig as HfCPOConfig
from trl import DPOConfig as HfDPOConfig
from trl import GKDConfig as HfGKDConfig
from trl import GRPOConfig as HfGRPOConfig
from trl import KTOConfig as HfKTOConfig
from trl import ORPOConfig as HfORPOConfig
from trl import PPOConfig as HfPPOConfig
from trl import RewardConfig as HfRewardConfig

from .arguments import GRPOArgumentsMixin, SwiftArgumentsMixin


@dataclass
class DPOConfig(SwiftArgumentsMixin, HfDPOConfig):
    pass


@dataclass
class CPOConfig(SwiftArgumentsMixin, HfCPOConfig):
    pass


@dataclass
class ORPOConfig(SwiftArgumentsMixin, HfORPOConfig):
    pass


@dataclass
class KTOConfig(SwiftArgumentsMixin, HfKTOConfig):
    pass


@dataclass
class RewardConfig(SwiftArgumentsMixin, HfRewardConfig):
    pass


@dataclass
class PPOConfig(SwiftArgumentsMixin, HfPPOConfig):
    pass


@dataclass
class GKDConfig(SwiftArgumentsMixin, HfGKDConfig):
    pass
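

# Usage sketch (hypothetical values, not taken from this repo): each wrapper above
# behaves like its TRL counterpart plus whatever extra fields SwiftArgumentsMixin
# contributes, so it is constructed exactly like the underlying TRL config, e.g.
#
#     args = DPOConfig(output_dir='output', beta=0.1, per_device_train_batch_size=2)
#
# `output_dir` and `per_device_train_batch_size` come from transformers'
# TrainingArguments and `beta` from trl's DPOConfig; mixin-specific fields are
# passed as additional keyword arguments in the same way.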


@dataclass
class GRPOConfig(GRPOArgumentsMixin, SwiftArgumentsMixin, HfGRPOConfig):
    # Extra stop strings applied when sampling completions during rollout.
    stop_words: List[str] = field(default_factory=list)

    def __post_init__(self):
        GRPOArgumentsMixin.__post_init__(self)
        SwiftArgumentsMixin.__post_init__(self)
        if self.cosine_max_len is None:
            self.cosine_max_len = self.max_completion_length
        if self.deepspeed and 'zero_optimization' in self.deepspeed and self.deepspeed['zero_optimization'][
                'stage'] == 3:
            # Disable ZeRO stage-3 parameter prefetch; see
            # https://github.com/modelscope/ms-swift/issues/3237
            self.deepspeed['zero_optimization']['stage3_prefetch_bucket_size'] = 0
            self.deepspeed_plugin.hf_ds_config.config['zero_optimization']['stage3_prefetch_bucket_size'] = 0
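            # Illustration (hypothetical config, not from this repo): with a ZeRO-3 json like
            #   {'zero_optimization': {'stage': 3, 'stage3_prefetch_bucket_size': 'auto'}}
            # the two assignments above pin 'stage3_prefetch_bucket_size' to 0 in both the
            # raw argument dict and the already-built DeepSpeed plugin config.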
        # GRPO needs complete batches, so always drop the last incomplete one; see
        # https://github.com/modelscope/ms-swift/issues/3863
        self.dataloader_drop_last = True

        # Default the rollout schedule: one generation round per
        # `gradient_accumulation_steps` steps, with the rollout batch sized to cover
        # every process for that many steps.
        num_processes = self.world_size
        if self.steps_per_generation is None:
            self.steps_per_generation = self.gradient_accumulation_steps
        if self.generation_batch_size is None:
            self.generation_batch_size = self.per_device_train_batch_size * num_processes * self.steps_per_generation
        self.check_num_generations()
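
    # Worked example for the defaults above (hypothetical values): with 4 processes,
    # per_device_train_batch_size=2 and steps_per_generation=8, the derived
    # generation_batch_size is 2 * 4 * 8 = 64.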

    def check_num_generations(self):
        # Check num_generations for trl < 0.18.
        num_processes = self.world_size
        if self.generation_batch_size % (self.per_device_train_batch_size * num_processes) != 0:
            raise ValueError(
                f'generation_batch_size ({self.generation_batch_size}) must be divisible by the global batch size '
                f'({self.per_device_train_batch_size * num_processes}).')
        self.steps_per_generation = self.generation_batch_size // (self.per_device_train_batch_size * num_processes)

        # Check that the effective batch size can be divided by the number of generations.
        if self.num_generations < 2:
            raise ValueError(
                'GRPO requires at least 2 generations per prompt to calculate the advantages. You provided '
                f'{self.num_generations}, which is less than the minimum required.')
        possible_values = [
            n_gen for n_gen in range(2, self.generation_batch_size + 1) if self.generation_batch_size % n_gen == 0
        ]
        if self.num_generations not in possible_values:
            raise ValueError(
                f'The effective train batch size ({num_processes} x {self.per_device_train_batch_size} x '
                f'{self.steps_per_generation}) must be evenly divisible by the number of generations per '
                f'prompt ({self.num_generations}). Given the current effective train batch size, the valid values for '
                f'the number of generations are: {possible_values}.')

        if self.eval_strategy != 'no':
            global_eval_batch_size = self.per_device_eval_batch_size * num_processes
            possible_values = [
                n_gen for n_gen in range(2, global_eval_batch_size + 1) if global_eval_batch_size % n_gen == 0
            ]
            if self.num_generations not in possible_values:
                raise ValueError(
                    f'The global eval batch size ({num_processes} x {self.per_device_eval_batch_size}) must be '
                    f'evenly divisible by the number of generations per prompt ({self.num_generations}). Given the '
                    'current global eval batch size, the valid values for the number of generations are: '
                    f'{possible_values}.')
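
    # Worked example (hypothetical values): with generation_batch_size=64 the valid
    # values for num_generations are the divisors of 64 that are >= 2, i.e.
    # [2, 4, 8, 16, 32, 64]; setting num_generations=6 would raise the ValueError above.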