"""
GPT-OSS Custom Training Configuration
Based on OpenAI's GPT-OSS fine-tuning tutorial
Fully customizable configuration for any dataset format
Supports specialized datasets like:
- legmlai/openhermes-fr (French instruction dataset)
- HuggingFaceH4/Multilingual-Thinking
- Custom prompt/completion formats
"""
from dataclasses import dataclass
from typing import Optional, Dict, List
@dataclass
class GPTOSSEnhancedCustomConfig:
"""Enhanced custom configuration for GPT-OSS fine-tuning with maximum flexibility"""
# ============================================================================
# CORE MODEL CONFIGURATION
# ============================================================================
trainer_type: str = "sft" # "sft" or "dpo"
model_name: str = "openai/gpt-oss-20b"
max_seq_length: int = 2048 # Customizable: 512, 1024, 2048, 4096, 8192
use_flash_attention: bool = True
use_gradient_checkpointing: bool = True
# ============================================================================
# TRAINING HYPERPARAMETERS - FULLY CUSTOMIZABLE
# ============================================================================
# Batch Configuration
batch_size: int = 4 # Per-device batch size (1-32 depending on GPU memory)
gradient_accumulation_steps: int = 4 # Effective batch = batch_size * accumulation * num_gpus
eval_batch_size: Optional[int] = None # If None, uses batch_size
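    # e.g., batch_size=4 with gradient_accumulation_steps=4 yields an effective
    # batch of 16 per device (64 globally across 4 GPUs)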
# Learning Rate Configuration
learning_rate: float = 2e-4 # Main learning rate (1e-5 to 5e-4 typical range)
min_lr: float = 2e-5 # Minimum learning rate for scheduler
warmup_ratio: float = 0.03 # Fraction of steps for warmup (0.01-0.1)
warmup_steps: Optional[int] = None # If set, overrides warmup_ratio
# Training Duration
num_train_epochs: float = 1.0 # Number of epochs (0.5, 1.0, 2.0, 3.0)
max_steps: Optional[int] = None # If set, overrides num_train_epochs
max_iters: Optional[int] = None # Legacy compatibility
# Regularization
weight_decay: float = 0.01 # L2 regularization (0.0-0.1)
max_grad_norm: float = 1.0 # Gradient clipping (0.5-2.0)
# ============================================================================
# OPTIMIZER CONFIGURATION
# ============================================================================
optimizer: str = "adamw_torch" # "adamw_torch", "adamw_hf", "sgd"
beta1: float = 0.9 # Adam beta1 parameter
beta2: float = 0.95 # Adam beta2 parameter (0.95-0.999)
eps: float = 1e-8 # Adam epsilon
# ============================================================================
# SCHEDULER CONFIGURATION
# ============================================================================
    scheduler: str = "cosine" # Default to a broadly compatible scheduler; TRL-specific schedulers are opt-in
lr_scheduler_kwargs: Optional[Dict] = None
# ============================================================================
# MIXED PRECISION & DISTRIBUTED TRAINING
# ============================================================================
fp16: bool = False # Use FP16 (not recommended for GPT-OSS)
bf16: bool = True # Use BF16 (recommended for GPT-OSS)
tf32: Optional[bool] = None # Use TF32 on A100/H100
ddp_backend: str = "nccl"
ddp_find_unused_parameters: bool = False
# ============================================================================
# LOGGING, EVALUATION & CHECKPOINTING
# ============================================================================
# Logging
logging_steps: int = 10 # Log every N steps
log_level: str = "info" # "debug", "info", "warning", "error"
# Evaluation
eval_strategy: str = "steps" # "no", "steps", "epoch"
eval_steps: int = 100 # Evaluate every N steps
eval_delay: float = 0 # Delay evaluation for N steps/epochs
eval_accumulation_steps: Optional[int] = None # Accumulate eval outputs
# Automatic split ratios when only a single training split is provided
    eval_ratio: float = 0.01 # Fraction of data held out for validation (0.0-0.1 typical)
    test_ratio: float = 0.01 # Fraction of data held out for test (0.0-0.1 typical)
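    # e.g., with the defaults above, a 1,000,000-row train split becomes
    # ~980,000 train / 10,000 validation / 10,000 test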
# Checkpointing
save_strategy: str = "steps" # "no", "steps", "epoch"
save_steps: int = 500 # Save checkpoint every N steps
    save_total_limit: Optional[int] = 3 # Keep at most N checkpoints on disk
save_only_model: bool = False # Save only model weights
# TRL packing (token packing of multiple samples into a single sequence)
# Some configs (e.g., openhermes_fr_memory_optimized) set this to True
packing: bool = False
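    # e.g., with packing=True and max_seq_length=2048, several short samples are
    # concatenated into each 2048-token block, reducing padding waste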
# Model Selection
metric_for_best_model: str = "eval_loss"
greater_is_better: bool = False
load_best_model_at_end: bool = True
# ============================================================================
# DATASET CONFIGURATION - ENHANCED FOR CUSTOM FORMATS
# ============================================================================
# Dataset Source
dataset_name: str = "legmlai/openhermes-fr" # Default to French OpenHermes
dataset_split: str = "train" # Dataset split to use
dataset_config: Optional[str] = None # Dataset configuration name
# Field Mapping - Customize for your dataset format
input_field: str = "prompt" # Field containing the input/prompt
target_field: str = "accepted_completion" # Field containing the target/completion
# Optional global conversational context
system_message: Optional[str] = None
developer_message: Optional[str] = None
# OpenHermes-FR specific fields
filter_bad_entries: bool = True # Filter entries marked as bad
bad_entry_field: str = "bad_entry" # Field indicating bad entries
bad_prompt_field: str = "bad_prompt_detected" # Field for bad prompts
bad_response_field: str = "bad_response_detected" # Field for bad responses
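    # Intended filter (applied by the training script, sketched here): drop any row where
    #   row[bad_entry_field] or row[bad_prompt_field] or row[bad_response_field] is True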
# Data Processing Options
concatenate_fields: bool = True # Combine input and target fields for training
field_separator: str = "\n\n### Response:\n" # Separator between input and target
add_eos_token: bool = True # Add EOS token at the end
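    # Illustrative composition (hypothetical record, assuming the defaults above):
    #   prompt              = "Explique la photosynthèse."
    #   accepted_completion = "La photosynthèse est le processus par lequel..."
    #   text = prompt + "\n\n### Response:\n" + accepted_completion  (+ EOS if add_eos_token)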
# Dataset Filtering & Sampling
max_samples: Optional[int] = None # Limit dataset size (e.g., 100000 for testing)
min_length: int = 10 # Minimum sequence length
max_length: Optional[int] = None # Maximum sequence length (None = use max_seq_length)
# Custom Dataset Formats Support
dataset_format: str = "openhermes_fr" # "openhermes_fr", "messages", "text", "custom", "medical_o1_sft", "preference"
# Medical o1 SFT (FreedomIntelligence/medical-o1-reasoning-SFT) mapping
question_field: str = "Question"
reasoning_field: str = "Complex_CoT"
response_field: str = "Response"
reason_prefix: str = "Reasoning: "
answer_prefix: str = "Final Answer: "
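    # A medical_o1_sft row is assumed to render roughly as:
    #   Question + "\n" + "Reasoning: " + Complex_CoT + "\n" + "Final Answer: " + Response
    # (the exact assembly lives in the training script; the prefixes are configurable above)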
# GPT-OSS Harmony Format Configuration
use_harmony_format: bool = True # Enable GPT-OSS harmony format
use_chat_template: bool = False # Set to True for messages format
chat_template_kwargs: Optional[Dict] = None
# ============================================================================
# TRACKIO MONITORING CONFIGURATION
# ============================================================================
enable_tracking: bool = True
trackio_url: Optional[str] = None
trackio_token: Optional[str] = None
log_artifacts: bool = True
log_metrics: bool = True
log_config: bool = True
experiment_name: Optional[str] = None
# ============================================================================
# HUGGING FACE INTEGRATION
# ============================================================================
hf_token: Optional[str] = None
dataset_repo: Optional[str] = None
push_to_hub: bool = False # Push model to HF Hub after training
hub_model_id: Optional[str] = None # HF Hub model ID
hub_private_repo: bool = False # Make HF repo private
# ============================================================================
# GPT-OSS SPECIFIC CONFIGURATIONS
# ============================================================================
# LoRA Configuration
use_lora: bool = True
lora_config: Optional[Dict] = None
# Quantization Configuration
use_quantization: bool = True
quantization_config: Optional[Dict] = None
# Model Loading Configuration
model_kwargs: Optional[Dict] = None
# Generation Configuration (for evaluation/testing)
generation_config: Optional[Dict] = None
# Preference-training (DPO) configuration
chosen_field: Optional[str] = None # Field name for preferred response (for DPO datasets)
rejected_field: Optional[str] = None # Field name for rejected response (for DPO datasets)
dpo_beta: float = 0.1 # DPO beta parameter
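    # Hypothetical preference record for dataset_format="preference":
    #   {"prompt": "...", "chosen": "...", "rejected": "..."}
    # used with chosen_field="chosen" and rejected_field="rejected"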
# ============================================================================
# MULTILINGUAL & DOMAIN SPECIFIC SETTINGS
# ============================================================================
# Language Support (for multilingual datasets)
primary_language: str = "fr" # Primary language code
reasoning_languages: Optional[List[str]] = None # Supported languages for reasoning
# Domain-specific settings
domain_focus: Optional[str] = None # "reasoning", "conversation", "instruction", "general"
# ============================================================================
# PERFORMANCE & MEMORY OPTIMIZATION
# ============================================================================
# Data Loading
dataloader_num_workers: int = 4 # Number of data loading workers
dataloader_pin_memory: bool = True # Pin memory for faster GPU transfer
dataloader_prefetch_factor: int = 2 # Prefetch factor for data loading
dataset_num_proc: Optional[int] = None # Parallel CPU processes for datasets map/filter ops
# Memory Management
max_memory_per_gpu: Optional[str] = None # e.g., "80GB", "40GB"
low_cpu_mem_usage: bool = True # Use low CPU memory loading
# Performance Optimizations
group_by_length: bool = True # Group sequences by length
length_column_name: str = "length" # Column name for sequence lengths
remove_unused_columns: bool = True # Remove unused dataset columns
def __post_init__(self):
"""Initialize default values and validate configuration"""
# ============================================================================
# LORA CONFIGURATION DEFAULTS
# ============================================================================
if self.lora_config is None:
self.lora_config = {
"r": 16, # Rank (4, 8, 16, 32, 64) - higher = more parameters
"lora_alpha": 32, # Scaling factor (usually 2*r)
"target_modules": "all-linear", # Apply LoRA to all linear layers
"target_parameters": [
"7.mlp.experts.gate_up_proj",
"7.mlp.experts.down_proj",
"15.mlp.experts.gate_up_proj",
"15.mlp.experts.down_proj",
"23.mlp.experts.gate_up_proj",
"23.mlp.experts.down_proj",
],
"bias": "none", # "none", "all", "lora_only"
"task_type": "CAUSAL_LM",
"lora_dropout": 0.05, # LoRA dropout rate
}
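        # Note: effective LoRA scaling is lora_alpha / r (32 / 16 = 2.0 with these defaults)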
# ============================================================================
# QUANTIZATION CONFIGURATION DEFAULTS
# ============================================================================
if self.quantization_config is None:
self.quantization_config = {
"dequantize": True, # Use Mxfp4Config as per GPT-OSS tutorial
"load_in_4bit": False, # Set to True for extreme memory optimization
"bnb_4bit_compute_dtype": "bfloat16", # For 4-bit quantization
"bnb_4bit_use_double_quant": True, # Double quantization
"bnb_4bit_quant_type": "nf4" # Quantization type
}
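        # The training script is expected to translate this dict into either an
        # Mxfp4Config(dequantize=True) or a BitsAndBytesConfig for 4-bit loading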
# ============================================================================
# MODEL LOADING CONFIGURATION DEFAULTS
# ============================================================================
if self.model_kwargs is None:
self.model_kwargs = {
"attn_implementation": "eager", # "eager", "flash_attention_2"
"torch_dtype": "auto", # "auto", "bfloat16", "float16"
"use_cache": False, # Disable KV cache for training
"device_map": "auto", # Automatic device mapping
"low_cpu_mem_usage": self.low_cpu_mem_usage,
}
# Add memory constraints if specified
if self.max_memory_per_gpu:
self.model_kwargs["max_memory"] = {0: self.max_memory_per_gpu}
# ============================================================================
# GENERATION CONFIGURATION DEFAULTS
# ============================================================================
if self.generation_config is None:
self.generation_config = {
"max_new_tokens": 512, # Maximum tokens to generate
"do_sample": True, # Use sampling
"temperature": 0.7, # Sampling temperature
"top_p": 0.9, # Nucleus sampling
"top_k": 50, # Top-k sampling
"repetition_penalty": 1.1, # Repetition penalty
"pad_token_id": None, # Will be set from tokenizer
"eos_token_id": None, # Will be set from tokenizer
}
# ============================================================================
# LANGUAGE CONFIGURATION DEFAULTS
# ============================================================================
if self.reasoning_languages is None:
if self.primary_language == "fr":
self.reasoning_languages = [
"French", "English", "Spanish", "Italian", "German"
]
else:
self.reasoning_languages = [
"English", "Spanish", "French", "Italian", "German",
"Chinese", "Hindi", "Japanese", "Korean", "Arabic"
]
# ============================================================================
# SCHEDULER CONFIGURATION DEFAULTS
# ============================================================================
if self.lr_scheduler_kwargs is None:
# Leave empty; training script will add TRL-specific keys only when needed
self.lr_scheduler_kwargs = {}
# ============================================================================
# CHAT TEMPLATE CONFIGURATION DEFAULTS (GPT-OSS Harmony Format)
# ============================================================================
if self.chat_template_kwargs is None:
self.chat_template_kwargs = {
"add_generation_prompt": True,
"tokenize": False,
"auto_insert_role": True,
# GPT-OSS Harmony Format specific settings
"reasoning_effort": "medium", # low, medium, high
"model_identity": "You are GPT-Tonic, a large language model trained by TonicAI.",
"builtin_tools": [], # Can include "browser" and/or "python"
}
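        # These kwargs are intended for the tokenizer's chat template, e.g.
        # tokenizer.apply_chat_template(messages, **self.chat_template_kwargs);
        # reasoning_effort and model_identity are GPT-OSS harmony-specific variables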
# ============================================================================
# VALIDATION AND COMPUTED VALUES
# ============================================================================
        # Compute the per-device effective batch size (multiply by the number of GPUs for the global size)
        effective_batch_size = self.batch_size * self.gradient_accumulation_steps
# Set warmup steps if not provided
if self.warmup_steps is None and self.max_steps:
self.warmup_steps = int(self.max_steps * self.warmup_ratio)
# Set max_length for dataset filtering
if self.max_length is None:
self.max_length = self.max_seq_length
# Validate configuration
self._validate_config()
# Print comprehensive configuration summary
self._print_config_summary(effective_batch_size)
def _validate_config(self):
"""Validate configuration parameters"""
# Validate batch configuration
if self.batch_size < 1:
raise ValueError("batch_size must be >= 1")
if self.gradient_accumulation_steps < 1:
raise ValueError("gradient_accumulation_steps must be >= 1")
# Validate learning rate
if self.learning_rate <= 0:
raise ValueError("learning_rate must be > 0")
if self.min_lr >= self.learning_rate:
raise ValueError("min_lr must be < learning_rate")
# Validate sequence length
if self.max_seq_length < 1:
raise ValueError("max_seq_length must be >= 1")
# Validate dataset format
valid_formats = ["openhermes_fr", "messages", "text", "custom", "medical_o1_sft", "preference"]
if self.dataset_format not in valid_formats:
raise ValueError(f"dataset_format must be one of {valid_formats}")
def _print_config_summary(self, effective_batch_size):
"""Print detailed configuration summary"""
print("\n" + "="*80)
print("πŸš€ GPT-OSS ENHANCED CUSTOM CONFIGURATION")
print("="*80)
print(f"πŸ“Š Model & Training:")
print(f" β€’ Model: {self.model_name}")
print(f" β€’ Dataset: {self.dataset_name} ({self.dataset_format})")
print(f" β€’ Primary Language: {self.primary_language}")
print(f" β€’ Sequence Length: {self.max_seq_length}")
print(f" β€’ Epochs: {self.num_train_epochs}")
print(f"\nπŸ”„ Batch Configuration:")
print(f" β€’ Per-device Batch Size: {self.batch_size}")
print(f" β€’ Gradient Accumulation: {self.gradient_accumulation_steps}")
print(f" β€’ Effective Batch Size: {effective_batch_size}")
print(f"\nπŸ“ˆ Learning Configuration:")
print(f" β€’ Learning Rate: {self.learning_rate}")
print(f" β€’ Min Learning Rate: {self.min_lr}")
print(f" β€’ Weight Decay: {self.weight_decay}")
print(f" β€’ Warmup Ratio: {self.warmup_ratio}")
print(f"\nπŸŽ›οΈ LoRA Configuration:")
print(f" β€’ Rank: {self.lora_config['r']}")
print(f" β€’ Alpha: {self.lora_config['lora_alpha']}")
print(f" β€’ Target Modules: {self.lora_config['target_modules']}")
print(f"\nπŸ“ Dataset Configuration:")
print(f" β€’ Input Field: {self.input_field}")
print(f" β€’ Target Field: {self.target_field}")
print(f" β€’ Filter Bad Entries: {self.filter_bad_entries}")
print(f" β€’ Max Samples: {self.max_samples or 'All'}")
if self.system_message or self.developer_message:
print(" β€’ Context messages set:")
if self.system_message:
print(" - system message: provided")
if self.developer_message:
print(" - developer message: provided")
print(f"\nπŸ’Ύ Memory & Performance:")
print(f" β€’ Mixed Precision: {'BF16' if self.bf16 else 'FP32'}")
print(f" β€’ Gradient Checkpointing: {self.use_gradient_checkpointing}")
print(f" β€’ Data Workers: {self.dataloader_num_workers}")
print(f" β€’ Group by Length: {self.group_by_length}")
print("="*80 + "\n")
# Create the config instance with OpenHermes-FR optimized defaults
config = GPTOSSEnhancedCustomConfig()
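
# A minimal usage sketch (assumption: training scripts import `config` from this
# module; every field referenced below is defined in the dataclass above):
if __name__ == "__main__":
    dpo_config = GPTOSSEnhancedCustomConfig(
        trainer_type="dpo",
        dataset_format="preference",
        chosen_field="chosen",
        rejected_field="rejected",
        learning_rate=5e-6,
        min_lr=5e-7,
    )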