from dataclasses import dataclass, field
from typing import List

from trl import CPOConfig as HfCPOConfig
from trl import DPOConfig as HfDPOConfig
from trl import GKDConfig as HfGKDConfig
from trl import GRPOConfig as HfGRPOConfig
from trl import KTOConfig as HfKTOConfig
from trl import ORPOConfig as HfORPOConfig
from trl import PPOConfig as HfPPOConfig
from trl import RewardConfig as HfRewardConfig

from .arguments import GRPOArgumentsMixin, SwiftArgumentsMixin


@dataclass
class DPOConfig(SwiftArgumentsMixin, HfDPOConfig):
    pass


@dataclass
class CPOConfig(SwiftArgumentsMixin, HfCPOConfig):
    pass


@dataclass
class ORPOConfig(SwiftArgumentsMixin, HfORPOConfig):
    pass


@dataclass
class KTOConfig(SwiftArgumentsMixin, HfKTOConfig):
    pass


@dataclass
class RewardConfig(SwiftArgumentsMixin, HfRewardConfig):
    pass


@dataclass
class PPOConfig(SwiftArgumentsMixin, HfPPOConfig):
    pass


@dataclass
class GKDConfig(SwiftArgumentsMixin, HfGKDConfig):
    pass


@dataclass
class GRPOConfig(GRPOArgumentsMixin, SwiftArgumentsMixin, HfGRPOConfig):
    stop_words: List[str] = field(default_factory=list)

    def __post_init__(self):
        GRPOArgumentsMixin.__post_init__(self)
        SwiftArgumentsMixin.__post_init__(self)
        if self.cosine_max_len is None:
            self.cosine_max_len = self.max_completion_length
        if self.deepspeed and 'zero_optimization' in self.deepspeed and self.deepspeed['zero_optimization'][
                'stage'] == 3:
            # https://github.com/modelscope/ms-swift/issues/3237
            self.deepspeed['zero_optimization']['stage3_prefetch_bucket_size'] = 0
            self.deepspeed_plugin.hf_ds_config.config['zero_optimization']['stage3_prefetch_bucket_size'] = 0
        # https://github.com/modelscope/ms-swift/issues/3863
        self.dataloader_drop_last = True

        num_processes = self.world_size
        if self.steps_per_generation is None:
            self.steps_per_generation = self.gradient_accumulation_steps
        if self.generation_batch_size is None:
            self.generation_batch_size = self.per_device_train_batch_size * num_processes * self.steps_per_generation

        self.check_num_generations()

    def check_num_generations(self):
        # check num_generations for trl < 0.18
        num_processes = self.world_size
        if self.generation_batch_size % (self.per_device_train_batch_size * num_processes) != 0:
            raise ValueError(
                f'generation_batch_size ({self.generation_batch_size}) must be divisible by the global batch size '
                f'({self.per_device_train_batch_size * num_processes}).')

        self.steps_per_generation = self.generation_batch_size // (self.per_device_train_batch_size * num_processes)

        # Check if the effective batch size can be divided by the number of generations
        if self.num_generations < 2:
            raise ValueError(
                'GRPO requires at least 2 generations per prompt to calculate the advantages. You provided '
                f'{self.num_generations}, which is less than the minimum required.')
        possible_values = [
            n_gen for n_gen in range(2, self.generation_batch_size + 1) if self.generation_batch_size % n_gen == 0
        ]

        if self.num_generations not in possible_values:
            raise ValueError(
                f'The effective train batch size ({num_processes} x {self.per_device_train_batch_size} x '
                f'{self.steps_per_generation}) must be evenly divisible by the number of generations per '
                f'prompt ({self.num_generations}). Given the current effective train batch size, the valid values for '
                f'the number of generations are: {possible_values}.')
        if self.eval_strategy != 'no':
            global_eval_batch_size = self.per_device_eval_batch_size * num_processes
            possible_values = [
                n_gen for n_gen in range(2, global_eval_batch_size + 1) if global_eval_batch_size % n_gen == 0
            ]
            if self.num_generations not in possible_values:
                raise ValueError(
                    f'The global eval batch size ({num_processes} x {self.per_device_eval_batch_size}) must be '
                    f'evenly divisible by the number of generations per prompt ({self.num_generations}). Given the '
                    'current global eval batch size, the valid values for the number of generations are: '
                    f'{possible_values}.')
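

# --- Illustrative sketch only, not part of the ms-swift API. ---
# check_num_generations() above requires num_generations to be a divisor (>= 2) of the
# effective generation batch, i.e. per_device_train_batch_size * world_size * steps_per_generation.
# The numbers below are assumptions chosen to show the arithmetic, mirroring the
# possible_values computation in check_num_generations().
if __name__ == '__main__':
    world_size = 8  # assumed number of processes
    per_device_train_batch_size = 4  # assumed per-device batch size
    steps_per_generation = 2  # assumed optimizer steps per generation round
    generation_batch_size = per_device_train_batch_size * world_size * steps_per_generation  # 64
    valid_num_generations = [n for n in range(2, generation_batch_size + 1) if generation_batch_size % n == 0]
    print(valid_num_generations)  # [2, 4, 8, 16, 32, 64] -> acceptable values for num_generations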