"""
GPT-OSS H100 Optimized Training Configuration
Based on OpenAI's GPT-OSS fine-tuning tutorial
Optimized for H100 GPU with maximum performance
"""
import os
from dataclasses import dataclass
from typing import Optional


@dataclass
class GPTOSSH100OptimizedConfig:
"""H100-optimized configuration for GPT-OSS fine-tuning"""
# Trainer type selection
trainer_type: str = "sft" # "sft" or "dpo"
# Model configuration - GPT-OSS specific with H100 optimizations
model_name: str = "openai/gpt-oss-20b"
max_seq_length: int = 4096 # Increased for H100
use_flash_attention: bool = True
use_gradient_checkpointing: bool = True
# Training configuration - H100 optimized
batch_size: int = 8 # Larger batch size for H100
gradient_accumulation_steps: int = 2 # Reduced for faster updates
learning_rate: float = 3e-4 # Higher LR for H100
weight_decay: float = 0.01
warmup_steps: int = 50 # Reduced warmup for rapid training
max_iters: int = 2000 # More iterations for H100
eval_interval: int = 50 # More frequent evaluation
log_interval: int = 5 # More frequent logging
save_interval: int = 200 # More frequent saving
# Optimizer configuration - H100 optimized
optimizer: str = "adamw_torch"
beta1: float = 0.9
beta2: float = 0.95
eps: float = 1e-8
# Scheduler configuration - faster learning
scheduler: str = "cosine_with_min_lr"
min_lr: float = 3e-5 # Higher min LR for H100
    lr_scheduler_kwargs: Optional[dict] = None
# Mixed precision - H100 optimized
fp16: bool = False # Use bf16 for H100
bf16: bool = True
# DDP configuration
ddp_backend: str = "nccl"
ddp_find_unused_parameters: bool = False
# Logging and saving - optimized for rapid training
save_steps: int = 200
eval_steps: int = 50
logging_steps: int = 5
save_total_limit: Optional[int] = 2 # Keep fewer checkpoints
# Evaluation
eval_strategy: str = "steps"
metric_for_best_model: str = "eval_loss"
greater_is_better: bool = False
load_best_model_at_end: bool = True
eval_accumulation_steps: Optional[int] = None
eval_ratio: float = 0.01
test_ratio: float = 0.01
# Data configuration
dataset_name: str = "HuggingFaceH4/Multilingual-Thinking"
dataset_split: str = "train"
input_field: str = "messages" # GPT-OSS uses messages format
    target_field: Optional[str] = None  # Not used for messages format
filter_bad_entries: bool = False
bad_entry_field: str = "bad_entry"
# Chat template configuration - GPT-OSS specific
use_chat_template: bool = True
    chat_template_kwargs: Optional[dict] = None
# Trackio monitoring configuration
enable_tracking: bool = True
trackio_url: Optional[str] = None
trackio_token: Optional[str] = None
log_artifacts: bool = True
log_metrics: bool = True
log_config: bool = True
experiment_name: Optional[str] = None
# HF Datasets configuration
hf_token: Optional[str] = None
dataset_repo: Optional[str] = None
# GPT-OSS specific configurations
# LoRA configuration for GPT-OSS - H100 optimized
use_lora: bool = True
    lora_config: Optional[dict] = None
# Quantization for GPT-OSS (MXFP4) - H100 optimized
use_quantization: bool = True
    quantization_config: Optional[dict] = None
# GPT-OSS specific model kwargs - H100 optimized
    model_kwargs: Optional[dict] = None
# H100-specific optimizations
dataloader_num_workers: int = 8 # More workers for H100
dataloader_pin_memory: bool = True
dataloader_prefetch_factor: int = 4 # Increased prefetch
tf32: Optional[bool] = None
    # DPO-specific fields (only used when trainer_type == "dpo")
    chosen_field: Optional[str] = None
    rejected_field: Optional[str] = None
    dpo_beta: float = 0.1
# Memory optimizations for H100
max_grad_norm: float = 1.0
group_by_length: bool = True # Group similar length sequences

    def __post_init__(self):
if self.chat_template_kwargs is None:
self.chat_template_kwargs = {
"add_generation_prompt": True,
"tokenize": False # GPT-OSS specific
}
if self.lr_scheduler_kwargs is None:
self.lr_scheduler_kwargs = {
"min_lr_rate": 0.1
}
if self.lora_config is None:
self.lora_config = {
"r": 16, # Increased for H100
"lora_alpha": 32, # Increased for H100
"target_modules": "all-linear",
"target_parameters": [
"7.mlp.experts.gate_up_proj",
"7.mlp.experts.down_proj",
"15.mlp.experts.gate_up_proj",
"15.mlp.experts.down_proj",
"23.mlp.experts.gate_up_proj",
"23.mlp.experts.down_proj",
]
}
if self.quantization_config is None:
self.quantization_config = {
"dequantize": True
}
if self.model_kwargs is None:
self.model_kwargs = {
"attn_implementation": "eager",
"torch_dtype": "auto",
"use_cache": False,
"device_map": "auto"
}
# Validate configuration
if self.fp16 and self.bf16:
raise ValueError("Cannot use both fp16 and bf16")
if self.max_seq_length > 131072: # 128k limit
raise ValueError("max_seq_length cannot exceed 131072")
        # Calculate approximate training statistics for H100
        effective_batch_size = self.batch_size * self.gradient_accumulation_steps
        # Multilingual-Thinking has ~1,000 examples; guard against a zero
        # step count when the effective batch size exceeds the dataset size
        steps_per_epoch = max(1, 1000 // effective_batch_size)
        epochs_for_max_iters = self.max_iters / steps_per_epoch
print(f"=== GPT-OSS H100 Optimized Configuration ===")
print(f"Effective batch size: {effective_batch_size}")
print(f"Steps per epoch: ~{steps_per_epoch}")
print(f"Training for ~{epochs_for_max_iters:.1f} epochs")
print(f"Total training steps: {self.max_iters}")
print(f"Learning rate: {self.learning_rate}")
print(f"Mixed precision: {'bf16' if self.bf16 else 'fp16'}")
print(f"Max sequence length: {self.max_seq_length}")
print(f"Gradient checkpointing: {self.use_gradient_checkpointing}")
print(f"LoRA rank: {self.lora_config['r']}")
print(f"Data loader workers: {self.dataloader_num_workers}")
print("=" * 50)
# Set default experiment name if not provided
if self.experiment_name is None:
self.experiment_name = "gpt_oss_h100_optimized"
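

# Hypothetical helper sketching how a training script might consume the
# lora_config and quantization_config dicts above. It assumes peft and a
# transformers build with GPT-OSS support are installed; peft's
# `target_parameters` argument only exists in recent releases, so treat
# this as an illustrative sketch rather than the canonical wiring.
def build_peft_and_quant_configs(cfg: GPTOSSH100OptimizedConfig):
    """Turn the plain config dicts into library config objects."""
    from peft import LoraConfig            # assumption: recent peft release
    from transformers import Mxfp4Config   # assumption: transformers with gpt-oss support

    peft_config = LoraConfig(**cfg.lora_config)             # r=16, alpha=32, expert layers
    quant_config = Mxfp4Config(**cfg.quantization_config)   # dequantize=True
    return peft_config, quant_config
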
def get_config(config_path: str) -> GPTOSSH100OptimizedConfig:
"""Load configuration from file or return default"""
if os.path.exists(config_path):
# Load from file if it exists
import importlib.util
spec = importlib.util.spec_from_file_location("config_module", config_path)
config_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config_module)
if hasattr(config_module, 'config'):
return config_module.config
else:
# Try to find a config class
for attr_name in dir(config_module):
attr = getattr(config_module, attr_name)
if isinstance(attr, GPTOSSH100OptimizedConfig):
return attr
    # Fall back to the default configuration
    return GPTOSSH100OptimizedConfig()
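
# An override file that get_config() can load only needs to expose a
# module-level `config` object; a minimal hypothetical my_config.py (the
# import path below depends on where this module actually lives):
#
#     from gpt_oss_config import GPTOSSH100OptimizedConfig
#     config = GPTOSSH100OptimizedConfig(batch_size=4, max_seq_length=2048)
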
# Default configuration instance
config = GPTOSSH100OptimizedConfig()
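

if __name__ == "__main__":
    # Minimal smoke test: with no override file at the (illustrative) path
    # below, get_config() falls back to the defaults defined in this module.
    cfg = get_config("custom_config.py")
    print(f"Experiment: {cfg.experiment_name}")
    print(f"Effective batch size: {cfg.batch_size * cfg.gradient_accumulation_steps}")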