| """ | |
| SmolLM3 DPO Training Configuration | |
| Based on nanoGPT structure but adapted for SmolLM3 DPO training | |
| """ | |
| import os | |
| from dataclasses import dataclass | |
| from typing import Optional | |
| from config.train_smollm3 import SmolLM3Config | |

@dataclass
class SmolLM3DPOConfig(SmolLM3Config):
    """Configuration for SmolLM3 DPO fine-tuning"""

    # Trainer type selection
    trainer_type: str = "dpo"  # Override default to use DPO trainer

    # DPO-specific configuration
    beta: float = 0.1
    max_prompt_length: int = 2048
    max_length: int = 4096

    # DPO training configuration
    dpo_beta: float = 0.1
    dpo_loss_type: str = "sigmoid"  # "sigmoid" or "hinge"
    dpo_alpha: float = 0.5

    # Reference model configuration
    ref_model_name: Optional[str] = None  # If None, defaults to model_name
    ref_model_peft_config: Optional[dict] = None

    # Preference dataset configuration
    preference_dataset_format: str = "dpo"  # "dpo", "rlhf", "custom"
    preference_dataset_text_field: str = "text"
    preference_dataset_prompt_field: str = "prompt"
    preference_dataset_chosen_field: str = "chosen"
    preference_dataset_rejected_field: str = "rejected"

    # DPO training arguments
    dpo_gradient_checkpointing: bool = True
    dpo_gradient_checkpointing_kwargs: Optional[dict] = None
    dpo_precompute_ref_log_probs: bool = False
    dpo_peft_config: Optional[dict] = None

    def __post_init__(self):
        super().__post_init__()

        # Set default values for DPO-specific settings
        if self.ref_model_name is None:
            self.ref_model_name = self.model_name

        if self.dpo_gradient_checkpointing_kwargs is None:
            self.dpo_gradient_checkpointing_kwargs = {
                "use_reentrant": False
            }

        if self.dpo_peft_config is None:
            self.dpo_peft_config = {
                "r": 16,
                "lora_alpha": 32,
                "lora_dropout": 0.1,
                "bias": "none",
                "task_type": "CAUSAL_LM",
            }

        # Validate DPO configuration
        if self.beta <= 0:
            raise ValueError("beta must be positive")
        if self.max_prompt_length > self.max_seq_length:
            raise ValueError("max_prompt_length cannot exceed max_seq_length")
        if self.max_length > self.max_seq_length:
            raise ValueError("max_length cannot exceed max_seq_length")

def get_dpo_config(config_path: str) -> SmolLM3DPOConfig:
    """Load DPO configuration from file or return default"""
    if os.path.exists(config_path):
        # Load the configuration module dynamically from the given file path
        import importlib.util
        spec = importlib.util.spec_from_file_location("config_module", config_path)
        config_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(config_module)

        if hasattr(config_module, 'config'):
            return config_module.config
        else:
            # Otherwise, look for a SmolLM3DPOConfig instance in the module
            for attr_name in dir(config_module):
                attr = getattr(config_module, attr_name)
                if isinstance(attr, SmolLM3DPOConfig):
                    return attr

    # Return default configuration
    return SmolLM3DPOConfig()
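
# Illustrative sketch (not part of the original module): a custom config file
# passed to get_dpo_config() is expected to expose either a module-level
# `config` instance or some SmolLM3DPOConfig attribute, for example:
#
#     # my_dpo_config.py  (hypothetical file; the import path below assumes
#     # this module lives at config/train_smollm3_dpo.py)
#     from config.train_smollm3_dpo import SmolLM3DPOConfig
#
#     config = SmolLM3DPOConfig(
#         beta=0.05,
#         dpo_loss_type="hinge",
#         max_prompt_length=1024,
#     )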

# Default DPO configuration instance
config = SmolLM3DPOConfig()
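
# Minimal usage sketch (illustrative; the config file path below is an assumed
# location, not something this module defines):
if __name__ == "__main__":
    dpo_config = get_dpo_config("config/train_smollm3_dpo.py")
    print(f"trainer_type={dpo_config.trainer_type}")
    print(f"dpo_beta={dpo_config.dpo_beta}, loss_type={dpo_config.dpo_loss_type}")
    print(f"max_prompt_length={dpo_config.max_prompt_length}, max_length={dpo_config.max_length}")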