"""
SmolLM3 DPO Training Configuration
Based on nanoGPT structure but adapted for SmolLM3 DPO training
"""
import os
from dataclasses import dataclass
from typing import Optional
from config.train_smollm3 import SmolLM3Config
@dataclass
class SmolLM3DPOConfig(SmolLM3Config):
    """Configuration for SmolLM3 DPO fine-tuning."""

    # Trainer type selection
    trainer_type: str = "dpo"  # Override the default to use the DPO trainer

    # DPO-specific configuration
    beta: float = 0.1
    max_prompt_length: int = 2048
    max_length: int = 4096

    # DPO training configuration
    dpo_beta: float = 0.1
    dpo_loss_type: str = "sigmoid"  # "sigmoid" or "hinge"
    dpo_alpha: float = 0.5

    # Reference model configuration
    ref_model_name: Optional[str] = None  # If None, falls back to model_name
    ref_model_peft_config: Optional[dict] = None

    # Preference dataset configuration
    preference_dataset_format: str = "dpo"  # "dpo", "rlhf", or "custom"
    preference_dataset_text_field: str = "text"
    preference_dataset_prompt_field: str = "prompt"
    preference_dataset_chosen_field: str = "chosen"
    preference_dataset_rejected_field: str = "rejected"

    # DPO training arguments
    dpo_gradient_checkpointing: bool = True
    dpo_gradient_checkpointing_kwargs: Optional[dict] = None  # filled in __post_init__
    dpo_precompute_ref_log_probs: bool = False
    dpo_peft_config: Optional[dict] = None  # filled in __post_init__
    def __post_init__(self):
        super().__post_init__()

        # Set default values for DPO-specific settings
        if self.ref_model_name is None:
            self.ref_model_name = self.model_name

        if self.dpo_gradient_checkpointing_kwargs is None:
            self.dpo_gradient_checkpointing_kwargs = {
                "use_reentrant": False
            }

        if self.dpo_peft_config is None:
            self.dpo_peft_config = {
                "r": 16,
                "lora_alpha": 32,
                "lora_dropout": 0.1,
                "bias": "none",
                "task_type": "CAUSAL_LM",
            }

        # Validate DPO configuration
        if self.beta <= 0:
            raise ValueError("beta must be positive")
        if self.max_prompt_length > self.max_seq_length:
            raise ValueError("max_prompt_length cannot exceed max_seq_length")
        if self.max_length > self.max_seq_length:
            raise ValueError("max_length cannot exceed max_seq_length")
def get_dpo_config(config_path: str) -> SmolLM3DPOConfig:
    """Load a DPO configuration from a Python file, or return the default."""
    if os.path.exists(config_path):
        # Load the module from the given file path
        import importlib.util

        spec = importlib.util.spec_from_file_location("config_module", config_path)
        if spec is not None and spec.loader is not None:
            config_module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(config_module)

            if hasattr(config_module, 'config'):
                return config_module.config

            # Otherwise, look for any SmolLM3DPOConfig instance in the module
            for attr_name in dir(config_module):
                attr = getattr(config_module, attr_name)
                if isinstance(attr, SmolLM3DPOConfig):
                    return attr

    # Fall back to the default configuration
    return SmolLM3DPOConfig()
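

# A file passed to get_dpo_config() should expose either a module-level
# `config` attribute or any SmolLM3DPOConfig instance. A minimal example
# config file (values are illustrative, and the import path assumes this
# module lives at config/train_smollm3_dpo.py):
#
#     from config.train_smollm3_dpo import SmolLM3DPOConfig
#     config = SmolLM3DPOConfig(beta=0.05, max_prompt_length=1024)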
# Default DPO configuration instance
config = SmolLM3DPOConfig()
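

if __name__ == "__main__":
    # Minimal usage sketch. "my_dpo_config.py" is a hypothetical user config;
    # if it does not exist, the default configuration is returned instead.
    cfg = get_dpo_config("my_dpo_config.py")
    print(cfg.trainer_type, cfg.dpo_loss_type, cfg.beta)
    # dpo_peft_config is shaped so a trainer could unpack it directly into
    # peft.LoraConfig(**cfg.dpo_peft_config) (an assumption about the trainer,
    # not something this module does itself).
    print(cfg.dpo_peft_config)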