"""
GPT-OSS Custom Training Configuration
Based on OpenAI's GPT-OSS fine-tuning tutorial
Fully customizable configuration for any dataset format

Supports specialized datasets like:
- legmlai/openhermes-fr (French instruction dataset)
- HuggingFaceH4/Multilingual-Thinking
- Custom prompt/completion formats
"""
import os
from dataclasses import dataclass
from typing import Optional, Dict, List, Union

@dataclass
class GPTOSSEnhancedCustomConfig:
    """Enhanced custom configuration for GPT-OSS fine-tuning with maximum flexibility"""
    
    # ============================================================================
    # CORE MODEL CONFIGURATION
    # ============================================================================
    trainer_type: str = "sft"  # "sft" or "dpo"
    model_name: str = "openai/gpt-oss-20b"
    max_seq_length: int = 2048  # Customizable: 512, 1024, 2048, 4096, 8192
    use_flash_attention: bool = True
    use_gradient_checkpointing: bool = True
    
    # ============================================================================
    # TRAINING HYPERPARAMETERS - FULLY CUSTOMIZABLE
    # ============================================================================
    # Batch Configuration
    batch_size: int = 4  # Per-device batch size (1-32 depending on GPU memory)
    gradient_accumulation_steps: int = 4  # Effective batch = batch_size * accumulation * num_gpus
    eval_batch_size: Optional[int] = None  # If None, uses batch_size
    
    # Learning Rate Configuration
    learning_rate: float = 2e-4  # Main learning rate (1e-5 to 5e-4 typical range)
    min_lr: float = 2e-5  # Minimum learning rate for scheduler
    warmup_ratio: float = 0.03  # Fraction of steps for warmup (0.01-0.1)
    warmup_steps: Optional[int] = None  # If set, overrides warmup_ratio
    
    # Training Duration
    num_train_epochs: float = 1.0  # Number of epochs (0.5, 1.0, 2.0, 3.0)
    max_steps: Optional[int] = None  # If set, overrides num_train_epochs
    max_iters: Optional[int] = None  # Legacy compatibility
    
    # Regularization
    weight_decay: float = 0.01  # L2 regularization (0.0-0.1)
    max_grad_norm: float = 1.0  # Gradient clipping (0.5-2.0)
    
    # ============================================================================
    # OPTIMIZER CONFIGURATION
    # ============================================================================
    optimizer: str = "adamw_torch"  # "adamw_torch", "adamw_hf", "sgd"
    beta1: float = 0.9  # Adam beta1 parameter
    beta2: float = 0.95  # Adam beta2 parameter (0.95-0.999)
    eps: float = 1e-8  # Adam epsilon
    
    # ============================================================================
    # SCHEDULER CONFIGURATION
    # ============================================================================
    scheduler: str = "cosine"  # Default to broadly compatible scheduler; TRL special is opt-in
    lr_scheduler_kwargs: Optional[Dict] = None
    
    # ============================================================================
    # MIXED PRECISION & DISTRIBUTED TRAINING
    # ============================================================================
    fp16: bool = False  # Use FP16 (not recommended for GPT-OSS)
    bf16: bool = True  # Use BF16 (recommended for GPT-OSS)
    tf32: Optional[bool] = None  # Use TF32 on A100/H100
    ddp_backend: str = "nccl"
    ddp_find_unused_parameters: bool = False
    
    # ============================================================================
    # LOGGING, EVALUATION & CHECKPOINTING
    # ============================================================================
    # Logging
    logging_steps: int = 10  # Log every N steps
    log_level: str = "info"  # "debug", "info", "warning", "error"
    
    # Evaluation
    eval_strategy: str = "steps"  # "no", "steps", "epoch"
    eval_steps: int = 100  # Evaluate every N steps
    eval_delay: float = 0  # Delay evaluation for N steps/epochs
    eval_accumulation_steps: Optional[int] = None  # Accumulate eval outputs
    # Automatic split ratios when only a single training split is provided
    eval_ratio: float = 0.01  # Fraction of data for validation (0.0-0.5 typical)
    test_ratio: float = 0.01  # Fraction of data for test (0.0-0.5 typical)
    
    # Checkpointing
    save_strategy: str = "steps"  # "no", "steps", "epoch"
    save_steps: int = 500  # Save checkpoint every N steps
    save_total_limit: Optional[int] = 3  # Keep at most N checkpoints on disk
    save_only_model: bool = False  # Save only model weights
    # TRL packing (token packing of multiple samples into a single sequence)
    # Some configs (e.g., openhermes_fr_memory_optimized) set this to True
    packing: bool = False
    
    # Model Selection
    metric_for_best_model: str = "eval_loss"
    greater_is_better: bool = False
    load_best_model_at_end: bool = True
    
    # ============================================================================
    # DATASET CONFIGURATION - ENHANCED FOR CUSTOM FORMATS
    # ============================================================================
    # Dataset Source
    dataset_name: str = "legmlai/openhermes-fr"  # Default to French OpenHermes
    dataset_split: str = "train"  # Dataset split to use
    dataset_config: Optional[str] = None  # Dataset configuration name
    
    # Field Mapping - Customize for your dataset format
    input_field: str = "prompt"  # Field containing the input/prompt
    target_field: str = "accepted_completion"  # Field containing the target/completion
    # Optional global conversational context
    system_message: Optional[str] = None
    developer_message: Optional[str] = None
    
    # OpenHermes-FR specific fields
    filter_bad_entries: bool = True  # Filter entries marked as bad
    bad_entry_field: str = "bad_entry"  # Field indicating bad entries
    bad_prompt_field: str = "bad_prompt_detected"  # Field for bad prompts
    bad_response_field: str = "bad_response_detected"  # Field for bad responses
    
    # Data Processing Options
    concatenate_fields: bool = True  # Combine input and target fields for training
    field_separator: str = "\n\n### Response:\n"  # Separator between input and target
    add_eos_token: bool = True  # Add EOS token at the end
    
    # Dataset Filtering & Sampling
    max_samples: Optional[int] = None  # Limit dataset size (e.g., 100000 for testing)
    min_length: int = 10  # Minimum sequence length
    max_length: Optional[int] = None  # Maximum sequence length (None = use max_seq_length)
    
    # Custom Dataset Formats Support
    dataset_format: str = "openhermes_fr"  # "openhermes_fr", "messages", "text", "custom", "medical_o1_sft", "preference"

    # Medical o1 SFT (FreedomIntelligence/medical-o1-reasoning-SFT) mapping
    question_field: str = "Question"
    reasoning_field: str = "Complex_CoT"
    response_field: str = "Response"
    reason_prefix: str = "Reasoning: "
    answer_prefix: str = "Final Answer: "
    
    # GPT-OSS Harmony Format Configuration
    use_harmony_format: bool = True  # Enable GPT-OSS harmony format
    use_chat_template: bool = False  # Set to True for messages format
    chat_template_kwargs: Optional[Dict] = None
    
    # ============================================================================
    # TRACKIO MONITORING CONFIGURATION
    # ============================================================================
    enable_tracking: bool = True
    trackio_url: Optional[str] = None
    trackio_token: Optional[str] = None
    log_artifacts: bool = True
    log_metrics: bool = True
    log_config: bool = True
    experiment_name: Optional[str] = None
    
    # ============================================================================
    # HUGGING FACE INTEGRATION
    # ============================================================================
    hf_token: Optional[str] = None
    dataset_repo: Optional[str] = None
    push_to_hub: bool = False  # Push model to HF Hub after training
    hub_model_id: Optional[str] = None  # HF Hub model ID
    hub_private_repo: bool = False  # Make HF repo private
    
    # ============================================================================
    # GPT-OSS SPECIFIC CONFIGURATIONS
    # ============================================================================
    # LoRA Configuration
    use_lora: bool = True
    lora_config: Optional[Dict] = None
    
    # Quantization Configuration  
    use_quantization: bool = True
    quantization_config: Optional[Dict] = None
    
    # Model Loading Configuration
    model_kwargs: Optional[Dict] = None
    
    # Generation Configuration (for evaluation/testing)
    generation_config: Optional[Dict] = None

    # Preference-training (DPO) configuration
    chosen_field: Optional[str] = None  # Field name for preferred response (for DPO datasets)
    rejected_field: Optional[str] = None  # Field name for rejected response (for DPO datasets)
    dpo_beta: float = 0.1  # DPO beta parameter
    
    # ============================================================================
    # MULTILINGUAL & DOMAIN SPECIFIC SETTINGS
    # ============================================================================
    # Language Support (for multilingual datasets)
    primary_language: str = "fr"  # Primary language code
    reasoning_languages: Optional[List[str]] = None  # Supported languages for reasoning
    
    # Domain-specific settings
    domain_focus: Optional[str] = None  # "reasoning", "conversation", "instruction", "general"
    
    # ============================================================================
    # PERFORMANCE & MEMORY OPTIMIZATION
    # ============================================================================
    # Data Loading
    dataloader_num_workers: int = 4  # Number of data loading workers
    dataloader_pin_memory: bool = True  # Pin memory for faster GPU transfer
    dataloader_prefetch_factor: int = 2  # Prefetch factor for data loading
    dataset_num_proc: Optional[int] = None  # Parallel CPU processes for datasets map/filter ops
    
    # Memory Management
    max_memory_per_gpu: Optional[str] = None  # e.g., "80GB", "40GB"
    low_cpu_mem_usage: bool = True  # Use low CPU memory loading
    
    # Performance Optimizations
    group_by_length: bool = True  # Group sequences by length
    length_column_name: str = "length"  # Column name for sequence lengths
    remove_unused_columns: bool = True  # Remove unused dataset columns
    
    def __post_init__(self):
        """Initialize default values and validate configuration"""
        
        # ============================================================================
        # LORA CONFIGURATION DEFAULTS
        # ============================================================================
        if self.lora_config is None:
            self.lora_config = {
                "r": 16,  # Rank (4, 8, 16, 32, 64) - higher = more parameters
                "lora_alpha": 32,  # Scaling factor (usually 2*r)
                "target_modules": "all-linear",  # Apply LoRA to all linear layers
                "target_parameters": [
                    "7.mlp.experts.gate_up_proj",
                    "7.mlp.experts.down_proj", 
                    "15.mlp.experts.gate_up_proj",
                    "15.mlp.experts.down_proj",
                    "23.mlp.experts.gate_up_proj", 
                    "23.mlp.experts.down_proj",
                ],
                "bias": "none",  # "none", "all", "lora_only"
                "task_type": "CAUSAL_LM",
                "lora_dropout": 0.05,  # LoRA dropout rate
            }
        
        # ============================================================================
        # QUANTIZATION CONFIGURATION DEFAULTS
        # ============================================================================
        if self.quantization_config is None:
            self.quantization_config = {
                "dequantize": True,  # Use Mxfp4Config as per GPT-OSS tutorial
                "load_in_4bit": False,  # Set to True for extreme memory optimization
                "bnb_4bit_compute_dtype": "bfloat16",  # For 4-bit quantization
                "bnb_4bit_use_double_quant": True,  # Double quantization
                "bnb_4bit_quant_type": "nf4"  # Quantization type
            }
        
        # ============================================================================
        # MODEL LOADING CONFIGURATION DEFAULTS
        # ============================================================================
        if self.model_kwargs is None:
            self.model_kwargs = {
                "attn_implementation": "eager",  # "eager", "flash_attention_2"
                "torch_dtype": "auto",  # "auto", "bfloat16", "float16"
                "use_cache": False,  # Disable KV cache for training
                "device_map": "auto",  # Automatic device mapping
                "low_cpu_mem_usage": self.low_cpu_mem_usage,
            }
            
            # Add memory constraints if specified
            if self.max_memory_per_gpu:
                self.model_kwargs["max_memory"] = {0: self.max_memory_per_gpu}
        
        # ============================================================================
        # GENERATION CONFIGURATION DEFAULTS
        # ============================================================================
        if self.generation_config is None:
            self.generation_config = {
                "max_new_tokens": 512,  # Maximum tokens to generate
                "do_sample": True,  # Use sampling
                "temperature": 0.7,  # Sampling temperature
                "top_p": 0.9,  # Nucleus sampling
                "top_k": 50,  # Top-k sampling
                "repetition_penalty": 1.1,  # Repetition penalty
                "pad_token_id": None,  # Will be set from tokenizer
                "eos_token_id": None,  # Will be set from tokenizer
            }
        
        # ============================================================================
        # LANGUAGE CONFIGURATION DEFAULTS
        # ============================================================================
        if self.reasoning_languages is None:
            if self.primary_language == "fr":
                self.reasoning_languages = [
                    "French", "English", "Spanish", "Italian", "German"
                ]
            else:
                self.reasoning_languages = [
                    "English", "Spanish", "French", "Italian", "German", 
                    "Chinese", "Hindi", "Japanese", "Korean", "Arabic"
                ]
        
        # ============================================================================
        # SCHEDULER CONFIGURATION DEFAULTS
        # ============================================================================
        if self.lr_scheduler_kwargs is None:
            # Leave empty; training script will add TRL-specific keys only when needed
            self.lr_scheduler_kwargs = {}
        
        # ============================================================================
        # CHAT TEMPLATE CONFIGURATION DEFAULTS (GPT-OSS Harmony Format)
        # ============================================================================
        if self.chat_template_kwargs is None:
            self.chat_template_kwargs = {
                "add_generation_prompt": True,
                "tokenize": False,
                "auto_insert_role": True,
                # GPT-OSS Harmony Format specific settings
                "reasoning_effort": "medium",  # low, medium, high
                "model_identity": "You are GPT-Tonic, a large language model trained by TonicAI.",
                "builtin_tools": [],  # Can include "browser" and/or "python"
            }
        
        # ============================================================================
        # VALIDATION AND COMPUTED VALUES
        # ============================================================================
        # Compute effective batch size
        effective_batch_size = self.batch_size * self.gradient_accumulation_steps
        
        # Set warmup steps if not provided
        if self.warmup_steps is None and self.max_steps:
            self.warmup_steps = int(self.max_steps * self.warmup_ratio)
        
        # Set max_length for dataset filtering
        if self.max_length is None:
            self.max_length = self.max_seq_length
        
        # Validate configuration
        self._validate_config()
        
        # Print comprehensive configuration summary
        self._print_config_summary(effective_batch_size)
    
    def _validate_config(self):
        """Validate configuration parameters"""
        
        # Validate batch configuration
        if self.batch_size < 1:
            raise ValueError("batch_size must be >= 1")
        if self.gradient_accumulation_steps < 1:
            raise ValueError("gradient_accumulation_steps must be >= 1")
            
        # Validate learning rate
        if self.learning_rate <= 0:
            raise ValueError("learning_rate must be > 0")
        if self.min_lr >= self.learning_rate:
            raise ValueError("min_lr must be < learning_rate")
            
        # Validate sequence length
        if self.max_seq_length < 1:
            raise ValueError("max_seq_length must be >= 1")
            
        # Validate dataset format
        valid_formats = ["openhermes_fr", "messages", "text", "custom", "medical_o1_sft", "preference"]
        if self.dataset_format not in valid_formats:
            raise ValueError(f"dataset_format must be one of {valid_formats}")
    
    def _print_config_summary(self, effective_batch_size):
        """Print detailed configuration summary"""
        
        print("\n" + "="*80)
        print("πŸš€ GPT-OSS ENHANCED CUSTOM CONFIGURATION")
        print("="*80)
        
        print(f"πŸ“Š Model & Training:")
        print(f"   β€’ Model: {self.model_name}")
        print(f"   β€’ Dataset: {self.dataset_name} ({self.dataset_format})")
        print(f"   β€’ Primary Language: {self.primary_language}")
        print(f"   β€’ Sequence Length: {self.max_seq_length}")
        print(f"   β€’ Epochs: {self.num_train_epochs}")
        
        print(f"\nπŸ”„ Batch Configuration:")
        print(f"   β€’ Per-device Batch Size: {self.batch_size}")
        print(f"   β€’ Gradient Accumulation: {self.gradient_accumulation_steps}")
        print(f"   β€’ Effective Batch Size: {effective_batch_size}")
        
        print(f"\nπŸ“ˆ Learning Configuration:")
        print(f"   β€’ Learning Rate: {self.learning_rate}")
        print(f"   β€’ Min Learning Rate: {self.min_lr}")
        print(f"   β€’ Weight Decay: {self.weight_decay}")
        print(f"   β€’ Warmup Ratio: {self.warmup_ratio}")
        
        print(f"\nπŸŽ›οΈ LoRA Configuration:")
        print(f"   β€’ Rank: {self.lora_config['r']}")
        print(f"   β€’ Alpha: {self.lora_config['lora_alpha']}")
        print(f"   β€’ Target Modules: {self.lora_config['target_modules']}")
        
        print(f"\nπŸ“ Dataset Configuration:")
        print(f"   β€’ Input Field: {self.input_field}")
        print(f"   β€’ Target Field: {self.target_field}")
        print(f"   β€’ Filter Bad Entries: {self.filter_bad_entries}")
        print(f"   β€’ Max Samples: {self.max_samples or 'All'}")
        if self.system_message or self.developer_message:
            print("   β€’ Context messages set:")
            if self.system_message:
                print("     - system message: provided")
            if self.developer_message:
                print("     - developer message: provided")
        
        print(f"\nπŸ’Ύ Memory & Performance:")
        print(f"   β€’ Mixed Precision: {'BF16' if self.bf16 else 'FP32'}")
        print(f"   β€’ Gradient Checkpointing: {self.use_gradient_checkpointing}")
        print(f"   β€’ Data Workers: {self.dataloader_num_workers}")
        print(f"   β€’ Group by Length: {self.group_by_length}")
        
        print("="*80 + "\n")

# Create the config instance with OpenHermes-FR optimized defaults
config = GPTOSSEnhancedCustomConfig()
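
# ----------------------------------------------------------------------------
# Illustrative overrides (a minimal sketch, not executed on import).
# All field names below exist in GPTOSSEnhancedCustomConfig; the dataset
# choices and hyperparameter values are assumptions for illustration only,
# not recommended settings.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    # SFT on a chat/"messages"-format dataset via the tokenizer chat template
    messages_config = GPTOSSEnhancedCustomConfig(
        dataset_name="HuggingFaceH4/Multilingual-Thinking",
        dataset_format="messages",
        use_chat_template=True,
        use_harmony_format=True,
        max_samples=1000,          # small slice for a quick smoke test
        num_train_epochs=1.0,
    )

    # Preference (DPO) training on a dataset with chosen/rejected columns
    dpo_config = GPTOSSEnhancedCustomConfig(
        trainer_type="dpo",
        dataset_format="preference",
        dataset_name="your-org/your-preference-dataset",  # hypothetical repo id
        input_field="prompt",
        chosen_field="chosen",
        rejected_field="rejected",
        dpo_beta=0.1,
        learning_rate=5e-6,
        min_lr=5e-7,               # must stay below learning_rate (see _validate_config)
    )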