"""
SmolLM3 Training Configuration for OpenHermes-FR Dataset - A100 Balanced
Optimized for good GPU utilization without running out of memory
"""
import os
from dataclasses import dataclass
from typing import Optional
from config.train_smollm3 import SmolLM3Config


@dataclass
class SmolLM3ConfigOpenHermesFRBalanced(SmolLM3Config):
"""Configuration for SmolLM3 fine-tuning with balanced A100 performance"""
# Model configuration - balanced for A100
model_name: str = "HuggingFaceTB/SmolLM3-3B"
max_seq_length: int = 12288 # Long context in SmolLM3
use_flash_attention: bool = True
use_gradient_checkpointing: bool = False # Disabled for A100 efficiency
# Training configuration - Balanced GPU utilization
batch_size: int = 8 # Moderate increase
gradient_accumulation_steps: int = 16 # Effective batch size = 8 * 16 = 128
learning_rate: float = 3.5e-6 # Slightly higher for larger effective batch
weight_decay: float = 0.01
warmup_steps: int = 1200 # More warmup for larger batch
max_iters: int = 18000 # More iterations for faster convergence
eval_interval: int = 1000 # Less frequent evaluation
log_interval: int = 25 # Less frequent logging
save_interval: int = 2000 # Less frequent saving
# Optimizer configuration - optimized for large batches
optimizer: str = "adamw_torch"
beta1: float = 0.9
beta2: float = 0.999 # Higher beta2 for stability
eps: float = 1e-8
# Scheduler configuration - faster training
scheduler: str = "cosine"
min_lr: float = 3.5e-7 # Lower min LR
# Mixed precision - A100 optimized
fp16: bool = False # Use bf16 for A100
bf16: bool = True # Better for A100
# DDP configuration
ddp_backend: str = "nccl"
ddp_find_unused_parameters: bool = False
# Logging and saving - optimized for fast training
save_steps: int = 2000
eval_steps: int = 1000
logging_steps: int = 25
save_total_limit: Optional[int] = 5 # Keep fewer checkpoints
# Evaluation
eval_strategy: str = "steps"
metric_for_best_model: str = "eval_loss"
greater_is_better: bool = False
load_best_model_at_end: bool = True
# OpenHermes-FR Dataset configuration
dataset_name: str = "legmlai/openhermes-fr"
dataset_split: str = "train"
input_field: str = "prompt"
target_field: str = "accepted_completion"
filter_bad_entries: bool = True
bad_entry_field: str = "bad_entry"
# Data configuration (not used for HF datasets but kept for compatibility)
data_dir: str = None
train_file: str = None
validation_file: Optional[str] = None
test_file: Optional[str] = None
# Chat template configuration
use_chat_template: bool = True
chat_template_kwargs: dict = None
# SFTTrainer-specific optimizations
packing: bool = False # Disable packing for better stability with long sequences
max_prompt_length: int = 12288 # Increased to handle longer prompts
max_completion_length: int = 8192 # long completion length
truncation: bool = True # Enable truncation for long sequences
# Trackio monitoring configuration
enable_tracking: bool = True
trackio_url: Optional[str] = None
trackio_token: Optional[str] = None
log_artifacts: bool = True
log_metrics: bool = True
log_config: bool = True
experiment_name: Optional[str] = None
# HF Datasets configuration
hf_token: Optional[str] = None
dataset_repo: Optional[str] = None
# Additional A100 optimizations for balanced performance
dataloader_num_workers: int = 10 # More workers for faster data loading
dataloader_pin_memory: bool = True
dataloader_prefetch_factor: int = 3 # Increased prefetch
# Memory optimizations
max_grad_norm: float = 1.0 # Gradient clipping
group_by_length: bool = True # Group similar length sequences
# Training duration calculations
# With 800k datapoints and effective batch size of 128:
# Steps per epoch = 800,000 / 128 = 6,250 steps
# For 3 passes: 6,250 * 3 = 18,750 steps
# Current max_iters = 18,000 (about 2.9 passes)

    def __post_init__(self):
        if self.chat_template_kwargs is None:
            self.chat_template_kwargs = {
                "add_generation_prompt": True,
                "no_think_system_message": True,  # True adds the /no_think tag
            }

        # Validate configuration
        if self.fp16 and self.bf16:
            raise ValueError("Cannot use both fp16 and bf16")
        if self.max_seq_length > 131072:  # 128k token limit
            raise ValueError("max_seq_length cannot exceed 131072")

        # Calculate and report training statistics
        effective_batch_size = self.batch_size * self.gradient_accumulation_steps
        steps_per_epoch = 800000 // effective_batch_size  # Approximate for the ~800k-row dataset
        epochs_for_max_iters = self.max_iters / steps_per_epoch

        print("=== A100 Balanced Configuration ===")
        print(f"Effective batch size: {effective_batch_size}")
        print(f"Steps per epoch: ~{steps_per_epoch}")
        print(f"Training for ~{epochs_for_max_iters:.1f} epochs")
        print(f"Total training steps: {self.max_iters}")
        print(f"Learning rate: {self.learning_rate}")
        print(f"Mixed precision: {'bf16' if self.bf16 else 'fp16'}")
        print(f"Max sequence length: {self.max_seq_length}")
        print(f"Gradient checkpointing: {self.use_gradient_checkpointing}")
        print(f"Batch size: {self.batch_size}")
        print(f"Gradient accumulation: {self.gradient_accumulation_steps}")
        print(f"Data loader workers: {self.dataloader_num_workers}")
        print(f"Prefetch factor: {self.dataloader_prefetch_factor}")
        print("=" * 50)

        # Set a default experiment name if not provided
        if self.experiment_name is None:
            self.experiment_name = "smollm3_openhermes_fr_balanced"
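

# --- Illustrative sketch (not from the original pipeline) ---
# Shows how the dataset fields above are expected to map onto the
# `datasets` library: load the configured split, then drop rows whose
# `bad_entry` flag is set. The flag is assumed to be a boolean column.
def _load_filtered_dataset(cfg: SmolLM3ConfigOpenHermesFRBalanced):
    """Load legmlai/openhermes-fr and filter flagged entries (sketch only)."""
    from datasets import load_dataset  # requires the `datasets` package

    dataset = load_dataset(cfg.dataset_name, split=cfg.dataset_split)
    if cfg.filter_bad_entries:
        dataset = dataset.filter(lambda example: not example[cfg.bad_entry_field])
    return dataset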

def get_config(config_path: str) -> SmolLM3ConfigOpenHermesFRBalanced:
    """Load a configuration from a Python file, or return the default."""
    if os.path.exists(config_path):
        # Load from file if it exists
        import importlib.util

        spec = importlib.util.spec_from_file_location("config_module", config_path)
        config_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(config_module)

        if hasattr(config_module, 'config'):
            return config_module.config
        # Otherwise look for a config instance defined under another name
        for attr_name in dir(config_module):
            attr = getattr(config_module, attr_name)
            if isinstance(attr, SmolLM3ConfigOpenHermesFRBalanced):
                return attr

    # Fall back to the default configuration
    return SmolLM3ConfigOpenHermesFRBalanced()
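

# --- Illustrative sketch (not from the original pipeline) ---
# One plausible way the chat-template kwargs above could be consumed when
# building the prompt side of each example. `no_think_system_message` is a
# project-specific flag (assumed here to prepend a "/no_think" system turn);
# the remaining kwargs pass straight through to `tokenizer.apply_chat_template`.
def _format_prompt(cfg: SmolLM3ConfigOpenHermesFRBalanced, tokenizer, example) -> str:
    kwargs = dict(cfg.chat_template_kwargs or {})
    messages = []
    if kwargs.pop("no_think_system_message", False):
        messages.append({"role": "system", "content": "/no_think"})
    messages.append({"role": "user", "content": example[cfg.input_field]})
    return tokenizer.apply_chat_template(messages, tokenize=False, **kwargs)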


# Default configuration instance
config = SmolLM3ConfigOpenHermesFRBalanced()
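

if __name__ == "__main__":
    # Smoke test (sketch): "overrides.py" is a hypothetical path; when it
    # does not exist, get_config falls back to the default instance above,
    # so running this module simply prints the configuration summary.
    cfg = get_config("overrides.py")
    print(f"Experiment: {cfg.experiment_name}")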