"""
SmolLM3 DPO Training Configuration
Based on the nanoGPT config structure, adapted for SmolLM3 DPO training
"""

import os
from dataclasses import dataclass
from typing import Optional
from config.train_smollm3 import SmolLM3Config

@dataclass
class SmolLM3DPOConfig(SmolLM3Config):
    """Configuration for SmolLM3 DPO fine-tuning"""
    
    # Trainer type selection
    trainer_type: str = "dpo"  # Override default to use DPO trainer
    
    # DPO-specific configuration
    beta: float = 0.1
    max_prompt_length: int = 2048
    max_length: int = 4096
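    # Presumably consumed by TRL's DPOTrainer, where max_prompt_length bounds
    # the prompt alone and max_length bounds the full prompt + completion pair.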
    
    # DPO training configuration
    dpo_beta: float = 0.1
    dpo_loss_type: str = "sigmoid"  # "sigmoid" or "hinge"
    dpo_alpha: float = 0.5
    
    # Reference model configuration
    ref_model_name: Optional[str] = None  # If None, falls back to model_name
    ref_model_peft_config: Optional[dict] = None
    
    # Preference dataset configuration
    preference_dataset_format: str = "dpo"  # "dpo", "rlhf", "custom"
    preference_dataset_text_field: str = "text"
    preference_dataset_prompt_field: str = "prompt"
    preference_dataset_chosen_field: str = "chosen"
    preference_dataset_rejected_field: str = "rejected"
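    # Illustrative record under the default "dpo" format (the field names are
    # the configurable defaults above):
    #   {"prompt": "...", "chosen": "...", "rejected": "..."}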
    
    # DPO training arguments
    dpo_gradient_checkpointing: bool = True
    dpo_gradient_checkpointing_kwargs: Optional[dict] = None
    dpo_precompute_ref_log_probs: bool = False
    dpo_peft_config: Optional[dict] = None
    
    def __post_init__(self):
        super().__post_init__()
        
        # Set default values for DPO-specific settings
        if self.ref_model_name is None:
            self.ref_model_name = self.model_name
        
        if self.dpo_gradient_checkpointing_kwargs is None:
            self.dpo_gradient_checkpointing_kwargs = {
                "use_reentrant": False
            }
        
        if self.dpo_peft_config is None:
            self.dpo_peft_config = {
                "r": 16,
                "lora_alpha": 32,
                "lora_dropout": 0.1,
                "bias": "none",
                "task_type": "CAUSAL_LM"
            }
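            # These keys line up with peft.LoraConfig arguments, so the dict
            # can be expanded directly, e.g. LoraConfig(**self.dpo_peft_config)
            # (assumption: the trainer builds its LoRA adapter this way).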
        
        # Validate DPO configuration
        if self.beta <= 0:
            raise ValueError("beta must be positive")
        
        if self.max_prompt_length > self.max_seq_length:
            raise ValueError("max_prompt_length cannot exceed max_seq_length")
        
        if self.max_length > self.max_seq_length:
            raise ValueError("max_length cannot exceed max_seq_length")

def get_dpo_config(config_path: str) -> SmolLM3DPOConfig:
    """Load DPO configuration from file or return default"""
    if os.path.exists(config_path):
        # Load from file if it exists
        import importlib.util
        spec = importlib.util.spec_from_file_location("config_module", config_path)
        if spec is None or spec.loader is None:
            raise ImportError(f"Could not load config module from {config_path}")
        config_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(config_module)
        
        if hasattr(config_module, 'config'):
            return config_module.config
        else:
            # Otherwise scan the module for a SmolLM3DPOConfig instance
            for attr_name in dir(config_module):
                attr = getattr(config_module, attr_name)
                if isinstance(attr, SmolLM3DPOConfig):
                    return attr
    
    # Return default configuration
    return SmolLM3DPOConfig()

# Default DPO configuration instance
config = SmolLM3DPOConfig()
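
# Usage sketch (illustrative; the path below is a hypothetical example):
# resolve a config from disk, falling back to the defaults above if missing.
if __name__ == "__main__":
    dpo_config = get_dpo_config("config/train_smollm3_dpo.py")
    print(f"trainer={dpo_config.trainer_type} beta={dpo_config.beta} "
          f"max_length={dpo_config.max_length}")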