"""
SmolLM3 H100 Lightweight Training Configuration
Optimized for rapid training on H100 with 80K Hermes-FR samples
"""

import os
from dataclasses import dataclass
from typing import Optional
from config.train_smollm3 import SmolLM3Config

@dataclass
class SmolLM3ConfigH100Lightweight(SmolLM3Config):
    """Configuration for SmolLM3 fine-tuning on OpenHermes-FR dataset - H100 Lightweight"""
    
    # Model configuration - optimized for H100
    model_name: str = "HuggingFaceTB/SmolLM3-3B"
    max_seq_length: int = 8192  # Increased for better context understanding
    use_flash_attention: bool = True
    use_gradient_checkpointing: bool = True  # Enabled for memory efficiency
    
    # Training configuration - H100 optimized for rapid training
    batch_size: int = 16  # Larger batch size for H100
    gradient_accumulation_steps: int = 4  # Reduced for faster updates
    learning_rate: float = 8e-6  # Slightly higher for rapid convergence
    weight_decay: float = 0.01
    warmup_steps: int = 50  # Reduced warmup for rapid training
    max_iters: Optional[int] = None  # Calculated from epochs if not set explicitly
    eval_interval: int = 50  # More frequent evaluation
    log_interval: int = 5  # More frequent logging
    save_interval: int = 200  # More frequent saving
    
    # Optimizer configuration - optimized for rapid training
    optimizer: str = "adamw_torch"
    beta1: float = 0.9
    beta2: float = 0.95
    eps: float = 1e-8
    
    # Scheduler configuration - faster learning
    scheduler: str = "cosine"
    min_lr: float = 2e-6  # Higher minimum LR
    
    # Mixed precision - using bf16, which the H100 supports natively
    # Note: bf16 has a wider dynamic range than fp16, so it avoids loss-scaling issues
    fp16: bool = False
    bf16: bool = True
    
    # Logging and saving - more frequent for rapid training
    save_steps: int = 200
    eval_steps: int = 50
    logging_steps: int = 5
    save_total_limit: Optional[int] = 2  # Keep fewer checkpoints
    
    # Evaluation
    eval_strategy: str = "steps"
    metric_for_best_model: str = "eval_loss"
    greater_is_better: bool = False
    load_best_model_at_end: bool = True
    
    # OpenHermes-FR Dataset configuration with sampling
    dataset_name: str = "legmlai/openhermes-fr"
    dataset_split: str = "train"
    input_field: str = "prompt"
    target_field: str = "completion"
    filter_bad_entries: bool = False
    bad_entry_field: str = "bad_entry"
    sample_size: int = 80000  # 80K samples for lightweight training
    sample_seed: int = 42  # For reproducibility
    
    # Data configuration (not used for HF datasets but kept for compatibility)
    data_dir: str = "my_dataset"
    train_file: str = "train.json"
    validation_file: Optional[str] = "validation.json"
    test_file: Optional[str] = None
    
    # Chat template configuration
    use_chat_template: bool = True
    chat_template_kwargs: Optional[dict] = None  # Populated with defaults in __post_init__
    
    # Trackio monitoring configuration
    enable_tracking: bool = True
    trackio_url: Optional[str] = None
    trackio_token: Optional[str] = None
    log_artifacts: bool = True
    log_metrics: bool = True
    log_config: bool = True
    experiment_name: Optional[str] = None
    
    # HF Datasets configuration
    hf_token: Optional[str] = None
    dataset_repo: Optional[str] = None
    
    # H100-specific optimizations
    dataloader_num_workers: int = 4  # Optimized for H100
    dataloader_pin_memory: bool = True
    dataloader_prefetch_factor: int = 2
    
    # Memory optimizations for rapid training
    max_grad_norm: float = 1.0
    group_by_length: bool = True  # Group similar length sequences
    
    # Training duration calculations
    # With 80k datapoints and effective batch size of 64:
    # Steps per epoch = 80,000 / 64 = 1,250 steps
    # For 1 epoch: 1,250 steps
    # For 2 epochs: 2,500 steps
    
    def __post_init__(self):
        if self.chat_template_kwargs is None:
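            # Flags passed through when applying the chat template;
            # enable_thinking=False disables SmolLM3's extended-thinking mode
            # for plain instruction tuning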
            self.chat_template_kwargs = {
                "enable_thinking": False,
                "add_generation_prompt": True,
                "no_think_system_message": True
            }
        
        # Validate configuration
        if self.fp16 and self.bf16:
            raise ValueError("Cannot use both fp16 and bf16")
        
        if self.max_seq_length > 131072:  # 128k limit
            raise ValueError("max_seq_length cannot exceed 131072")
        
        # Calculate training statistics
        effective_batch_size = self.batch_size * self.gradient_accumulation_steps
        steps_per_epoch = self.sample_size // effective_batch_size  # For 80k dataset
        epochs_for_max_iters = self.max_iters / steps_per_epoch if self.max_iters else 1
        
        print(f"=== H100 Lightweight Training Configuration ===")
        print(f"Effective batch size: {effective_batch_size}")
        print(f"Steps per epoch: ~{steps_per_epoch}")
        print(f"Training for ~{epochs_for_max_iters:.1f} epochs")
        print(f"Total training steps: {self.max_iters or 'auto'}")
        print(f"Learning rate: {self.learning_rate}")
        print(f"Mixed precision: {'fp16' if self.fp16 else 'bf16'}")
        print(f"Max sequence length: {self.max_seq_length}")
        print(f"Gradient checkpointing: {self.use_gradient_checkpointing}")
        print(f"Dataset sample size: {self.sample_size}")
        print("=" * 50)
        
        # Set default experiment name if not provided
        if self.experiment_name is None:
            self.experiment_name = "smollm3_h100_lightweight"

def get_config(config_path: str) -> SmolLM3ConfigH100Lightweight:
    """Load configuration from file or return default"""
    if os.path.exists(config_path):
        # Load from file if it exists
        import importlib.util
        spec = importlib.util.spec_from_file_location("config_module", config_path)
        config_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(config_module)
        
        if hasattr(config_module, 'config'):
            return config_module.config
        else:
            # Otherwise, search the module for a config instance
            for attr_name in dir(config_module):
                attr = getattr(config_module, attr_name)
                if isinstance(attr, SmolLM3ConfigH100Lightweight):
                    return attr
    
    # Return default configuration
    return SmolLM3ConfigH100Lightweight()
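
# Example of an override file that get_config() can load. The file name,
# import path, and overridden fields below are illustrative assumptions;
# any module defining a `config` instance of SmolLM3ConfigH100Lightweight
# will be picked up:
#
#     # my_overrides.py (hypothetical)
#     from config.train_smollm3_h100_lightweight import SmolLM3ConfigH100Lightweight
#     config = SmolLM3ConfigH100Lightweight(batch_size=8, max_iters=2500)
#
#     # caller
#     cfg = get_config("my_overrides.py")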

# Default configuration instance
config = SmolLM3ConfigH100Lightweight()
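
if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the training pipeline):
    # instantiating the class prints the summary via __post_init__; the
    # asserts just confirm the numbers quoted in the comments above
    # (effective batch 16 * 4 = 64; 80,000 / 64 = 1,250 steps per epoch).
    cfg = get_config("nonexistent_override.py")  # Falls back to the defaults
    effective = cfg.batch_size * cfg.gradient_accumulation_steps
    assert effective == 64
    assert cfg.sample_size // effective == 1250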