SmolFactory / scripts /training /train_gpt_oss.py
Tonic's picture
improves spaces deployment , configuration for custom settings , adds interface for spaces deployment
5f8b28d
#!/usr/bin/env python3
"""
GPT-OSS Training Script
Specialized training script for OpenAI's GPT-OSS models
Based on the GPT-OSS fine-tuning tutorial
"""
import os
import sys
import argparse
import inspect
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer
try:
from trl import DPOTrainer
except Exception: # pragma: no cover - optional import depending on TRL version
DPOTrainer = None
from datasets import load_dataset
from pathlib import Path
# Import monitoring utilities from project src for persistent logging
try:
from src.monitoring import create_monitor_from_config # type: ignore
except Exception:
create_monitor_from_config = None # type: ignore
# Ensure project root and config package are importable for configs that do `from config...` imports
# Make repository-local packages importable no matter which working directory
# the script is launched from. This file lives at scripts/training/, two levels
# below the project root, hence parents[2].
project_root = Path(__file__).resolve().parents[2]
if str(project_root) not in sys.path:
    # Put project root early so top-level packages like `config` can be resolved
    sys.path.insert(0, str(project_root))
config_dir = project_root / "config"
if str(config_dir) not in sys.path:
    # Also prepend the config/ directory itself (ahead of project_root), so the
    # modules inside it can be imported by bare name.
    # NOTE(review): this lets any module inside config/ shadow same-named
    # top-level modules/packages — confirm this is the intended precedence.
    sys.path.insert(0, str(config_dir))
# Ensure 'src' is importable for modules like 'monitoring', 'model', etc., but do not shadow `config`
src_dir = project_root / "src"
if str(src_dir) not in sys.path:
    # Append (lowest priority) so a `src/config.py` cannot shadow the `config` package
    sys.path.append(str(src_dir))
# If a module named 'config' is already cached and it is a plain module (it has
# no __path__, i.e. it is not a package), evict it from the import cache so a
# later `import config...` re-resolves against sys.path and can find the real
# `config/` package with its submodules.
try:
    if 'config' in sys.modules and not hasattr(sys.modules['config'], '__path__'):
        del sys.modules['config']
except Exception:
    # Best-effort cleanup; never block script start-up on it.
    pass
# Reduce tokenizer thread contention and improve CUDA allocator behavior.
# setdefault keeps any value the user already exported in the environment.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
def load_gpt_oss_model_and_tokenizer(config):
    """Load the GPT-OSS tokenizer and causal-LM model described by ``config``.

    Quantization is selected from ``config.quantization_config`` (an optional
    dict; a missing or empty value means no quantization):

    * ``{"load_in_4bit": True}`` -> bitsandbytes NF4 4-bit with double
      quantization and bf16 compute dtype.
    * ``{"dequantize": True}`` or ``{"mxfp4_config": {"enabled": True}}`` ->
      ``Mxfp4Config(dequantize=True)`` when the installed transformers provides
      it; otherwise no quantization (a warning is printed).

    ``config.model_kwargs`` (optional dict) overrides the default
    ``from_pretrained`` kwargs (eager attention, bf16, ``use_cache=False``,
    ``device_map="auto"``). A string ``torch_dtype`` is mapped to a torch dtype:
    "bf16"/"bfloat16", "fp16"/"float16"/"half", "fp32"/"float32"/"float";
    "auto" is passed through unchanged. Any other string falls back to bfloat16
    with a warning. When MXFP4 is used, eager attention, bf16 and
    ``use_cache=False`` are forced regardless of overrides.

    Args:
        config: Training config object; ``config.model_name`` is required.

    Returns:
        Tuple ``(model, tokenizer)``.
    """
    print("Loading GPT-OSS tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    print("Loading GPT-OSS model with quantization...")
    from transformers import BitsAndBytesConfig

    # A config without quantization_config (or with an empty one) means "no quantization".
    quant_settings = getattr(config, "quantization_config", None) or {}
    mxfp4_settings = quant_settings.get("mxfp4_config")
    wants_mxfp4 = bool(quant_settings.get("dequantize")) or (
        isinstance(mxfp4_settings, dict) and bool(mxfp4_settings.get("enabled", False))
    )

    quantization_config = None
    # True only when an Mxfp4Config was actually built (it may be unavailable).
    use_mxfp4 = False
    if quant_settings.get("load_in_4bit"):
        # bitsandbytes 4-bit NF4 (memory optimized)
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    elif wants_mxfp4:
        # Mxfp4Config only exists in newer transformers releases (as per tutorial).
        try:
            from transformers import Mxfp4Config
            quantization_config = Mxfp4Config(dequantize=True)
            use_mxfp4 = True
        except ImportError:
            print("Warning: Mxfp4Config not available, using no quantization")

    # Defaults first; config-provided model_kwargs win
    # (e.g. attn_implementation="kernels-community/vllm-flash-attn3").
    model_kwargs = {
        "attn_implementation": "eager",
        "torch_dtype": torch.bfloat16,
        "use_cache": False,
        "device_map": "auto",
    }
    cfg_model_kwargs = getattr(config, "model_kwargs", None)
    if isinstance(cfg_model_kwargs, dict):
        model_kwargs.update(cfg_model_kwargs)

    # Normalize a string torch_dtype coming from config files.
    dtype_value = model_kwargs.get("torch_dtype")
    if isinstance(dtype_value, str):
        dtype_str = dtype_value.lower()
        if dtype_str in {"bf16", "bfloat16"}:
            model_kwargs["torch_dtype"] = torch.bfloat16
        elif dtype_str in {"fp16", "float16", "half"}:
            model_kwargs["torch_dtype"] = torch.float16
        elif dtype_str in {"fp32", "float32", "float"}:
            model_kwargs["torch_dtype"] = torch.float32
        elif dtype_str != "auto":
            # Unknown value: fall back to bfloat16 for a safer memory footprint on A100/H100.
            print(f"Warning: unrecognized torch_dtype {dtype_value!r}; using bfloat16")
            model_kwargs["torch_dtype"] = torch.bfloat16

    # Ensure we have an offload folder for tight-memory setups.
    model_kwargs.setdefault("offload_folder", os.path.join(str(project_root), "offload"))

    if quantization_config is not None:
        model_kwargs["quantization_config"] = quantization_config
    if use_mxfp4:
        # Follow the tutorial exactly for MXFP4: eager attention + bf16, no KV cache.
        model_kwargs["attn_implementation"] = "eager"
        model_kwargs["torch_dtype"] = torch.bfloat16
        model_kwargs["use_cache"] = False

    model = AutoModelForCausalLM.from_pretrained(config.model_name, **model_kwargs)
    return model, tokenizer
def setup_lora_for_gpt_oss(model, config):
    """Wrap ``model`` with LoRA adapters configured for GPT-OSS.

    Settings are read from ``config.lora_config`` (an optional dict; missing or
    empty means "use all defaults"):

    * ``r`` (default 8) and ``lora_alpha`` (default 16)
    * ``target_modules`` (default ``"all-linear"``)
    * ``target_parameters`` (default: the MoE expert ``gate_up_proj`` /
      ``down_proj`` parameters of blocks 7, 15 and 23, as in the tutorial)

    Args:
        model: The base model to adapt.
        config: Training config object.

    Returns:
        The PEFT-wrapped model (its trainable-parameter summary is printed).
    """
    print("Setting up LoRA for GPT-OSS...")

    # Expert parameters to adapt by default (names as used in the GPT-OSS tutorial;
    # the leading number presumably is the decoder layer index — verify if changed).
    default_target_parameters = [
        "7.mlp.experts.gate_up_proj",
        "7.mlp.experts.down_proj",
        "15.mlp.experts.gate_up_proj",
        "15.mlp.experts.down_proj",
        "23.mlp.experts.gate_up_proj",
        "23.mlp.experts.down_proj",
    ]
    # Tolerate configs that don't define lora_config at all.
    lora_settings = getattr(config, "lora_config", None) or {}

    lora_config = LoraConfig(
        r=lora_settings.get("r", 8),
        lora_alpha=lora_settings.get("lora_alpha", 16),
        target_modules=lora_settings.get("target_modules", "all-linear"),
        target_parameters=lora_settings.get("target_parameters", default_target_parameters),
    )
    peft_model = get_peft_model(model, lora_config)
    peft_model.print_trainable_parameters()
    return peft_model
def load_dataset_from_config(config):
    """Load the configured dataset split, then filter and format it.

    Reads ``dataset_name`` (default ``HuggingFaceH4/Multilingual-Thinking``),
    ``dataset_split`` (default ``train``) and the optional ``dataset_config``
    subset name from ``config``, then runs ``apply_dataset_filtering`` followed
    by ``process_dataset_format``.

    Returns:
        The processed dataset.
    """
    name = getattr(config, 'dataset_name', 'HuggingFaceH4/Multilingual-Thinking')
    split = getattr(config, 'dataset_split', 'train')
    subset = getattr(config, 'dataset_config', None)

    print(f"Loading dataset: {name}")
    print(f"Dataset split: {split}")
    if subset:
        print(f"Dataset config: {subset}")
        dataset = load_dataset(name, subset, split=split)
    else:
        dataset = load_dataset(name, split=split)
    print(f"Original dataset size: {len(dataset)} examples")

    # Filter on the raw columns first, then reshape into trainer-ready columns.
    filtered = apply_dataset_filtering(dataset, config)
    formatted = process_dataset_format(filtered, config)

    print(f"Final dataset size: {len(formatted)} examples")
    return formatted
def build_scheduler_kwargs(config):
    """Return ``lr_scheduler_kwargs`` compatible with the configured scheduler.

    Works on a copy of ``config.lr_scheduler_kwargs`` (a non-dict value is
    treated as empty), so the config object is never mutated.

    * ``config.scheduler == 'cosine_with_min_lr'`` (TRL's scheduler): ensure one
      of ``min_lr`` / ``min_lr_rate`` is present. It is filled from
      ``config.min_lr`` (default 1e-6); ``min_lr_rate=0.1`` is used if
      ``config.min_lr`` is None, and ``0.001`` if it is not numeric. Warmup keys
      are stripped.
    * Any other scheduler: strip ``min_lr``, ``min_lr_rate`` and the warmup
      keys to avoid unexpected-kwarg errors in Transformers' native schedulers.

    Args:
        config: Training config object.

    Returns:
        A new dict of scheduler kwargs.
    """
    raw_kwargs = getattr(config, 'lr_scheduler_kwargs', {}) or {}
    # Copy: the dict belongs to the config and must not be altered as a side effect.
    skw = dict(raw_kwargs) if isinstance(raw_kwargs, dict) else {}
    warmup_keys = ('warmup_steps', 'num_warmup_steps', 'warmup_ratio')

    # Not TRL's special scheduler: drop incompatible keys and stop.
    if getattr(config, 'scheduler', None) != 'cosine_with_min_lr':
        for key in ('min_lr', 'min_lr_rate') + warmup_keys:
            skw.pop(key, None)
        return skw

    # TRL cosine-with-min-lr needs either an absolute or a relative floor.
    if 'min_lr' not in skw and 'min_lr_rate' not in skw:
        min_lr_cfg = getattr(config, 'min_lr', 1e-6)
        try:
            if min_lr_cfg is not None:
                skw['min_lr'] = float(min_lr_cfg)
            else:
                skw['min_lr_rate'] = 0.1
        except Exception:
            skw['min_lr_rate'] = 0.001

    # Warmup-related keys conflict with some TRL scheduler signatures.
    for key in warmup_keys:
        skw.pop(key, None)
    return skw
def apply_dataset_filtering(dataset, config):
    """Drop flagged and out-of-range examples, then optionally subsample.

    Steps, each controlled by ``config`` attributes:

    1. If ``filter_bad_entries`` is truthy, drop rows whose flag column
       (``bad_entry_field``, ``bad_prompt_field``, ``bad_response_field``;
       defaults ``bad_entry`` / ``bad_prompt_detected`` /
       ``bad_response_detected``) is truthy. Flag columns absent from the
       dataset are skipped.
    2. Keep rows whose ``len(input_field) + len(target_field)`` is at least
       ``min_length`` (default 10; ``None`` means no minimum). If ``max_length``
       is truthy, also require the total to be at most ``max_length``. Missing
       or ``None`` field values count as length 0.
    3. If ``max_samples`` is set and smaller than the dataset, keep a random
       subset of that size (shuffle seed 42).

    Args:
        dataset: A ``datasets.Dataset`` (uses ``column_names``, ``filter``,
            ``shuffle``, ``select`` and ``len``).
        config: Training config object.

    Returns:
        The filtered dataset.
    """
    # Parallel workers for datasets ops
    try:
        num_proc = getattr(config, 'dataset_num_proc', None) or (os.cpu_count() or 1)
    except Exception:
        num_proc = 1

    def _drop_flagged(ds, field):
        """Keep only rows whose ``field`` flag is falsy or missing."""
        def _keep(example, _field=field):
            return not example.get(_field, False)
        return ds.filter(_keep, num_proc=num_proc)

    if getattr(config, 'filter_bad_entries', False):
        bad_entry_field = getattr(config, 'bad_entry_field', 'bad_entry')
        bad_prompt_field = getattr(config, 'bad_prompt_field', 'bad_prompt_detected')
        bad_response_field = getattr(config, 'bad_response_field', 'bad_response_detected')
        original_size = len(dataset)
        if bad_entry_field in dataset.column_names:
            dataset = _drop_flagged(dataset, bad_entry_field)
            print(f"Filtered {original_size - len(dataset)} bad entries")
        if bad_prompt_field in dataset.column_names:
            dataset = _drop_flagged(dataset, bad_prompt_field)
            print(f"Filtered bad prompts, remaining: {len(dataset)} examples")
        if bad_response_field in dataset.column_names:
            dataset = _drop_flagged(dataset, bad_response_field)
            print(f"Filtered bad responses, remaining: {len(dataset)} examples")

    # `or 0`: a None min_length means "no minimum" instead of crashing on `None > 0`.
    min_length = getattr(config, 'min_length', 10) or 0
    max_length = getattr(config, 'max_length', None)
    input_field = getattr(config, 'input_field', 'prompt')
    target_field = getattr(config, 'target_field', 'accepted_completion')
    if min_length > 0 or max_length:
        def length_filter(example):
            # `or ''` so None values count as empty instead of making len() raise.
            total_len = len(example.get(input_field) or '') + len(example.get(target_field) or '')
            if total_len < min_length:
                return False
            if max_length and total_len > max_length:
                return False
            return True
        original_size = len(dataset)
        dataset = dataset.filter(length_filter, num_proc=num_proc)
        print(f"Length filtering: {original_size} -> {len(dataset)} examples")

    # Apply sampling if specified
    max_samples = getattr(config, 'max_samples', None)
    if max_samples and len(dataset) > max_samples:
        dataset = dataset.shuffle(seed=42).select(range(max_samples))
        print(f"Sampled {max_samples} examples from dataset")
    return dataset
def _build_harmony_text(
user_content: str,
assistant_content: str,
add_eos_token: bool = True,
system_message: str | None = None,
developer_message: str | None = None,
) -> str:
"""Compose a Harmony-formatted conversation with optional system/developer messages.
Structure (training):
<|start|>system<|message|>...<|end|> (optional)
<|start|>developer<|message|>...<|end|> (optional)
<|start|>user<|message|>...<|end|>
<|start|>assistant<|channel|>final<|message|>...<|return|>
"""
parts: list[str] = []
if system_message:
parts.append(f"<|start|>system<|message|>{system_message}<|end|>")
if developer_message:
parts.append(f"<|start|>developer<|message|>{developer_message}<|end|>")
parts.append(f"<|start|>user<|message|>{user_content}<|end|>")
parts.append(f"<|start|>assistant<|channel|>final<|message|>{assistant_content}")
if add_eos_token:
parts[-1] += "<|return|>"
else:
parts[-1] += "<|end|>"
return "".join(parts)
def format_gpt_oss_harmony(
    prompt: str,
    completion: str,
    add_eos_token: bool = True,
    system_message: str | None = None,
    developer_message: str | None = None,
) -> str:
    """Render one prompt/completion pair as a GPT-OSS Harmony training string.

    Public wrapper around ``_build_harmony_text``: ``prompt`` becomes the user
    turn and ``completion`` the assistant ``final`` channel turn, optionally
    preceded by system/developer turns.
    Reference template: `https://huggingface.co/openai/gpt-oss-20b/raw/main/chat_template.jinja`.
    """
    rendered = _build_harmony_text(
        prompt,
        completion,
        add_eos_token=add_eos_token,
        system_message=system_message,
        developer_message=developer_message,
    )
    return rendered
def format_gpt_oss_harmony_prompt(
prompt: str,
system_message: str | None = None,
developer_message: str | None = None,
) -> str:
"""Prefix-only Harmony prompt up to assistant content marker for DPO, with optional context."""
parts: list[str] = []
if system_message:
parts.append(f"<|start|>system<|message|>{system_message}<|end|>")
if developer_message:
parts.append(f"<|start|>developer<|message|>{developer_message}<|end|>")
parts.append(f"<|start|>user<|message|>{prompt}<|end|><|start|>assistant<|channel|>final<|message|>")
return "".join(parts)
def process_dataset_format(dataset, config):
    """Convert raw examples into the columns the trainer expects.

    DPO (``config.trainer_type == 'dpo'``) produces ``prompt`` / ``chosen`` /
    ``rejected`` columns, either:

    * from a ``dataset_format='preference'`` dataset, where the canonical
      ``chosen`` / ``rejected`` columns win over the configured names; or
    * from the configured ``chosen_field`` / ``rejected_field``.

    With Harmony enabled, the prompt is wrapped by
    ``format_gpt_oss_harmony_prompt`` and the responses get ``<|return|>``
    appended when ``add_eos_token`` is true.

    SFT (any other trainer type) dispatches on ``config.dataset_format``:

    * ``openhermes_fr``: ``input_field`` + ``target_field`` become a ``text``
      column, or separate ``input``/``output`` columns if
      ``concatenate_fields`` is false.
    * ``messages``: a list of ``{"role", "content"}`` dicts in ``input_field``
      becomes a ``text`` column.
    * ``medical_o1_sft``: question / reasoning / response become a ``text``
      column.
    * ``text``: ``input_field`` is copied to ``text``.
    * ``custom``: no processing.
    * any other value: no processing, with a printed warning.

    With ``use_harmony_format`` the ``text`` column uses GPT-OSS Harmony
    formatting. Otherwise it falls back to plain text terminated by ``</s>``
    when ``add_eos_token`` is true. Missing or ``None`` text fields are treated
    as empty strings.

    Args:
        dataset: A ``datasets.Dataset`` (uses ``map`` and ``column_names``).
        config: Training config object.

    Returns:
        The processed dataset.

    Raises:
        ValueError: DPO was requested without preference data configuration.
    """
    # Parallel workers for datasets ops
    try:
        num_proc = getattr(config, 'dataset_num_proc', None) or (os.cpu_count() or 1)
    except Exception:
        num_proc = 1

    dataset_format = getattr(config, 'dataset_format', 'openhermes_fr')
    input_field = getattr(config, 'input_field', 'prompt')
    target_field = getattr(config, 'target_field', 'accepted_completion')
    concatenate_fields = getattr(config, 'concatenate_fields', True)
    field_separator = getattr(config, 'field_separator', '\n\n### Response:\n')
    add_eos_token = getattr(config, 'add_eos_token', True)
    use_harmony_format = getattr(config, 'use_harmony_format', True)
    trainer_type = getattr(config, 'trainer_type', 'sft')
    system_message = getattr(config, 'system_message', None)
    developer_message = getattr(config, 'developer_message', None)

    print(f"Processing dataset format: {dataset_format}")
    print(f"Input field: {input_field}, Target field: {target_field}")
    print(f"GPT-OSS Harmony Format: {'Enabled' if use_harmony_format else 'Disabled'}")

    # ---- Preference format for DPO training (chosen/rejected pairs) ----
    if trainer_type == 'dpo':
        chosen_field = getattr(config, 'chosen_field', None)
        rejected_field = getattr(config, 'rejected_field', None)

        def _to_preference(prompt_val, chosen_val, rejected_val):
            """Build one DPO record, Harmony-wrapping the prompt when enabled."""
            if use_harmony_format:
                prompt_text = format_gpt_oss_harmony_prompt(
                    prompt_val,
                    system_message=system_message,
                    developer_message=developer_message,
                )
                # The prompt already ends at the assistant content marker, so the
                # responses only need the terminator.
                suffix = "<|return|>" if add_eos_token else ''
                return {
                    "prompt": prompt_text,
                    "chosen": (chosen_val or '') + suffix,
                    "rejected": (rejected_val or '') + suffix,
                }
            return {"prompt": prompt_val, "chosen": chosen_val, "rejected": rejected_val}

        if dataset_format == 'preference':
            def id_map(example):
                # Canonical 'chosen'/'rejected' columns win over configured field names.
                return _to_preference(
                    example.get(input_field, ''),
                    example.get('chosen', example.get(chosen_field or 'chosen', '')),
                    example.get('rejected', example.get(rejected_field or 'rejected', '')),
                )
            return dataset.map(id_map, remove_columns=dataset.column_names, num_proc=num_proc)

        # Custom preference mapping via configured field names
        if chosen_field and rejected_field:
            def to_pref(example):
                return _to_preference(
                    example.get(input_field, ''),
                    example.get(chosen_field, ''),
                    example.get(rejected_field, ''),
                )
            return dataset.map(to_pref, remove_columns=dataset.column_names, num_proc=num_proc)

        raise ValueError("DPO training requires preference data. Please set dataset_format='preference' with 'prompt', 'chosen', 'rejected' columns, or specify 'chosen_field' and 'rejected_field' in the config.")

    # ---- SFT formats ----
    if dataset_format == "openhermes_fr":
        # Process OpenHermes-FR format: prompt + accepted_completion
        def format_openhermes_fr(example):
            # Treat missing/None values as empty so they neither break concatenation
            # nor leak the literal string "None" into the training text.
            prompt = example.get(input_field, '') or ''
            completion = example.get(target_field, '') or ''
            if not concatenate_fields:
                # Keep separate for more advanced training setups
                return {"input": prompt, "output": completion}
            if use_harmony_format:
                text = format_gpt_oss_harmony(
                    prompt,
                    completion,
                    add_eos_token,
                    system_message=system_message,
                    developer_message=developer_message,
                )
            else:
                # Fallback to standard format with separator
                text = prompt + field_separator + completion
                if add_eos_token:
                    text += "</s>"
            return {"text": text}

        dataset = dataset.map(format_openhermes_fr, remove_columns=dataset.column_names, num_proc=num_proc)

    elif dataset_format == "messages":
        # Process messages format (like HuggingFaceH4/Multilingual-Thinking)
        def _plain_transcript(messages):
            """Render messages as "role: content" lines (non-Harmony fallback)."""
            text = ""
            for message in messages:
                role = message.get("role", "")
                content = message.get("content", "")
                text += f"{role}: {content}\n"
            if add_eos_token:
                text += "</s>"
            return text

        def format_messages(example):
            # `or []` so a None value doesn't crash len() / iteration.
            messages = example.get(input_field, []) or []
            if use_harmony_format and len(messages) >= 2:
                # Harmony rendering keeps only the last user and last assistant turn.
                user_message = ""
                assistant_message = ""
                for message in messages:
                    role = message.get("role", "")
                    if role == "user":
                        user_message = message.get("content", "")
                    elif role == "assistant":
                        assistant_message = message.get("content", "")
                if user_message and assistant_message:
                    return {"text": format_gpt_oss_harmony(
                        user_message,
                        assistant_message,
                        add_eos_token,
                        system_message=system_message,
                        developer_message=developer_message,
                    )}
            return {"text": _plain_transcript(messages)}

        dataset = dataset.map(format_messages, remove_columns=dataset.column_names, num_proc=num_proc)

    elif dataset_format == "medical_o1_sft":
        # Process Medical-o1 SFT format: Question | Complex_CoT | Response
        # NOTE(review): input_field/target_field default to 'prompt' /
        # 'accepted_completion', so the 'Question'/'Response' fallbacks below only
        # apply when those are falsy; set question_field/response_field explicitly
        # for FreedomIntelligence/medical-o1-reasoning-SFT.
        question_field = getattr(config, 'question_field', input_field or 'Question')
        reasoning_field = getattr(config, 'reasoning_field', 'Complex_CoT')
        response_field = getattr(config, 'response_field', target_field or 'Response')
        reason_prefix = getattr(config, 'reason_prefix', 'Reasoning: ')
        answer_prefix = getattr(config, 'answer_prefix', 'Final Answer: ')

        def format_medical(example):
            q = example.get(question_field, '') or ''
            cot = example.get(reasoning_field, '') or ''
            ans = example.get(response_field, '') or ''
            # Combine reasoning and final answer in a single assistant turn
            assistant_text = "\n\n".join(
                [s for s in [
                    f"{reason_prefix}{cot}".strip() if cot else '',
                    f"{answer_prefix}{ans}".strip() if ans else ''
                ] if s]
            ) or ans
            if use_harmony_format:
                text = format_gpt_oss_harmony(
                    q,
                    assistant_text,
                    add_eos_token,
                    system_message=system_message,
                    developer_message=developer_message,
                )
            else:
                text = f"Q: {q}\n\n{assistant_text}"
                if add_eos_token:
                    text += "</s>"
            return {"text": text}

        dataset = dataset.map(format_medical, remove_columns=dataset.column_names, num_proc=num_proc)

    elif dataset_format == "text":
        # Process plain text format
        text_field = input_field

        def format_text(example):
            text = example.get(text_field, '') or ''
            if add_eos_token:
                text += "</s>"
            return {"text": text}

        dataset = dataset.map(format_text, remove_columns=dataset.column_names, num_proc=num_proc)

    elif dataset_format == "custom":
        # Custom format - user handles this in their config
        print("Using custom dataset format - no automatic processing")

    else:
        # Previously silent: surface typos / unsupported formats.
        print(f"Warning: unknown dataset_format '{dataset_format}'; leaving dataset unprocessed")

    return dataset
def split_dataset(dataset, config):
    """Return ``(train, eval, test)`` splits for ``dataset``.

    * A non-``Dataset`` input that can be indexed by split name (e.g. a
      ``DatasetDict``) is returned as its existing splits. Eval is
      ``"validation"``, falling back to ``"eval"`` when that is missing or
      empty. Test is ``"test"``.
    * Otherwise ``config.eval_ratio`` / ``config.test_ratio`` (default 0.01
      each) are clamped at 0. If together they reach 0.9 or more, both are
      rescaled proportionally to sum to 0.2. Unparseable ratios fall back to
      0.01 / 0.01. If both are 0, ``(dataset, None, None)`` is returned.
    * The test split is carved out first, then eval from the remaining pool,
      with ``eval_ratio`` interpreted as a fraction of the whole dataset.
      Shuffling and splitting use seed 42.
    """
    from datasets import Dataset

    if not isinstance(dataset, Dataset):
        # Already-split container: reuse its existing splits.
        try:
            return (
                dataset["train"],
                dataset.get("validation") or dataset.get("eval"),
                dataset.get("test"),
            )
        except Exception:
            # Not addressable by split name; fall through and split it directly.
            pass

    eval_ratio = getattr(config, 'eval_ratio', 0.01)
    test_ratio = getattr(config, 'test_ratio', 0.01)
    try:
        eval_ratio = max(0.0, float(eval_ratio))
        test_ratio = max(0.0, float(test_ratio))
        combined = eval_ratio + test_ratio
        if combined >= 0.9:
            # Extreme hold-out requested: shrink both so they sum to 0.2.
            factor = 0.2 / max(1e-9, combined)
            eval_ratio *= factor
            test_ratio *= factor
    except Exception:
        eval_ratio, test_ratio = 0.01, 0.01

    if eval_ratio <= 0 and test_ratio <= 0:
        return dataset, None, None

    train_split = dataset.shuffle(seed=42)

    test_split = None
    if test_ratio > 0:
        held_out = train_split.train_test_split(test_size=test_ratio, seed=42)
        train_split, test_split = held_out["train"], held_out["test"]

    eval_split = None
    if eval_ratio > 0:
        # Convert the global eval fraction into a fraction of what is left.
        remaining_fraction = 1.0 - test_ratio
        relative_eval = eval_ratio / remaining_fraction if remaining_fraction > 0 else eval_ratio
        held_out = train_split.train_test_split(test_size=relative_eval, seed=42)
        train_split, eval_split = held_out["train"], held_out["test"]

    try:
        print(f"Created splits -> train: {len(train_split)}, eval: {len(eval_split) if eval_split else 0}, test: {len(test_split) if test_split else 0}")
    except Exception:
        pass
    return train_split, eval_split, test_split
def setup_trackio_tracking(config):
    """Create a Trackio API client when tracking is enabled and configured.

    The Trackio target comes from ``config.trackio_url``, falling back to the
    ``TRACKIO_URL`` and then ``TRACKIO_SPACE_ID`` environment variables. The
    token is ``config.trackio_token`` or the ``HF_TOKEN`` environment variable.
    The client class is imported from ``scripts/trackio_tonic/``, which is
    added to ``sys.path`` at most once.

    Args:
        config: Training config object.

    Returns:
        A ``TrackioAPIClient`` instance, or ``None`` in any of these cases:
        tracking is disabled, no target is configured, or the client module
        cannot be imported.
    """
    if not getattr(config, 'enable_tracking', False):
        print("Trackio tracking disabled or URL not provided")
        return None

    # Resolve Trackio URL from config or environment
    trackio_url = (
        getattr(config, 'trackio_url', None)
        or os.environ.get('TRACKIO_URL')
        or os.environ.get('TRACKIO_SPACE_ID')
    )
    if not trackio_url:
        print("Trackio tracking enabled but no TRACKIO_URL/TRACKIO_SPACE_ID provided; skipping Trackio setup")
        return None

    print(f"Setting up Trackio tracking: {trackio_url}")

    # The client module lives next to this script's directory; avoid adding the
    # same sys.path entry again on repeated calls.
    client_dir = os.path.join(os.path.dirname(__file__), '..', 'trackio_tonic')
    if client_dir not in sys.path:
        sys.path.append(client_dir)
    try:
        from trackio_api_client import TrackioAPIClient
    except ImportError as err:
        # Consistent with the other "cannot set up tracking" paths: warn and continue.
        print(f"Warning: could not import TrackioAPIClient ({err}); skipping Trackio setup")
        return None

    return TrackioAPIClient(
        space_id=trackio_url,
        hf_token=getattr(config, 'trackio_token', None) or os.environ.get('HF_TOKEN'),
    )
def create_sft_config(config, output_dir):
    """Build ``TrainingArguments`` for GPT-OSS SFT from ``config``.

    Numeric settings are coerced to int/float with defaults. The scheduler type
    may be overridden with the ``GPT_OSS_SCHEDULER`` environment variable. A
    ``cosine_with_min_lr`` request falls back to ``cosine`` (with no scheduler
    kwargs) when the installed TRL's helper looks incompatible.
    ``report_to`` is ``"trackio"`` when ``config.enable_tracking`` is set and
    ``"none"`` otherwise. ``dataloader_prefetch_factor`` is only passed when
    ``dataloader_num_workers > 0``.

    For compatibility across transformers versions, ``evaluation_strategy`` is
    renamed to ``eval_strategy`` when needed, and kwargs not accepted by the
    installed ``TrainingArguments`` are dropped.

    Args:
        config: Training config object (read via getattr with defaults).
        output_dir: Directory for checkpoints and outputs.

    Returns:
        A ``transformers.TrainingArguments`` instance.
    """
    print("Creating enhanced SFT configuration...")

    def _as_int(value, default):
        """Coerce ``value`` to int; ``default`` for None or unparseable values."""
        if value is None:
            return int(default)
        try:
            return int(value)
        except Exception:
            return int(default)

    def _as_float(value, default):
        """Coerce ``value`` to float; ``default`` for None or unparseable values."""
        if value is None:
            return float(default)
        try:
            return float(value)
        except Exception:
            return float(default)

    # Training duration; max_steps=-1 is the transformers "disabled" sentinel.
    num_train_epochs = _as_float(getattr(config, 'num_train_epochs', 1.0), 1.0)
    max_steps = _as_int(getattr(config, 'max_steps', None), -1)
    warmup_ratio = _as_float(getattr(config, 'warmup_ratio', 0.03), 0.03)
    # Always an int (never None) so schedulers can compare it numerically.
    warmup_steps = _as_int(getattr(config, 'warmup_steps', 0), 0)
    learning_rate = _as_float(getattr(config, 'learning_rate', 2e-4), 2e-4)

    # Scheduler; the GPT_OSS_SCHEDULER env var (set by the CLI) wins over config.scheduler.
    lr_scheduler_type = os.environ.get('GPT_OSS_SCHEDULER', getattr(config, 'scheduler', 'cosine'))
    lr_scheduler_kwargs = build_scheduler_kwargs(config)
    if lr_scheduler_type == 'cosine_with_min_lr':
        # Some TRL versions call get_cosine_with_min_lr_schedule_with_warmup with
        # 'warmup_steps' instead of 'num_warmup_steps' and fail with an
        # "unexpected keyword argument" TypeError. If the helper is missing,
        # doesn't accept 'warmup_steps', or can't be inspected, use plain cosine.
        # NOTE(review): transformers also ships a native 'cosine_with_min_lr'
        # scheduler; confirm that probing trl.trainer.utils is the right check.
        fallback_reason = None
        try:
            from trl.trainer import utils as trl_utils  # type: ignore
            if hasattr(trl_utils, 'get_cosine_with_min_lr_schedule_with_warmup'):
                sched_sig = inspect.signature(trl_utils.get_cosine_with_min_lr_schedule_with_warmup)
                if 'warmup_steps' not in sched_sig.parameters:
                    fallback_reason = "Incompatible TRL scheduler signature detected"
            else:
                fallback_reason = "TRL min-lr cosine scheduler not available"
        except Exception:
            fallback_reason = "Unable to verify TRL scheduler"
        if fallback_reason is not None:
            print(f"Warning: {fallback_reason}; falling back to 'cosine'.")
            lr_scheduler_type = 'cosine'
            lr_scheduler_kwargs = {}

    # Batch configuration
    per_device_train_batch_size = _as_int(getattr(config, 'batch_size', 2), 2)
    per_device_eval_batch_size = _as_int(getattr(config, 'eval_batch_size', per_device_train_batch_size), per_device_train_batch_size)
    gradient_accumulation_steps = _as_int(getattr(config, 'gradient_accumulation_steps', 1), 1)

    # Evaluation and logging
    eval_strategy = getattr(config, 'eval_strategy', 'steps')
    eval_steps = _as_int(getattr(config, 'eval_steps', 100), 100)
    eval_accumulation_steps = _as_int(getattr(config, 'eval_accumulation_steps', 1), 1)
    logging_steps = _as_int(getattr(config, 'logging_steps', 10), 10)

    # Saving configuration
    save_strategy = getattr(config, 'save_strategy', 'steps')
    save_steps = _as_int(getattr(config, 'save_steps', 500), 500)
    save_total_limit = _as_int(getattr(config, 'save_total_limit', 3), 3)

    # Mixed precision
    fp16 = bool(getattr(config, 'fp16', False))
    bf16 = bool(getattr(config, 'bf16', True))
    tf32 = bool(getattr(config, 'tf32', False))

    # Regularization
    weight_decay = _as_float(getattr(config, 'weight_decay', 0.01), 0.01)
    max_grad_norm = _as_float(getattr(config, 'max_grad_norm', 1.0), 1.0)

    # Data loading. TrainingArguments rejects a prefetch factor when data is
    # loaded in the main process, so only set it when worker processes are used.
    dataloader_num_workers = _as_int(getattr(config, 'dataloader_num_workers', 4), 4)
    dataloader_prefetch_factor = (
        _as_int(getattr(config, 'dataloader_prefetch_factor', 2), 2)
        if dataloader_num_workers > 0
        else None
    )

    print(f" • Epochs: {num_train_epochs}")
    print(f" • Learning rate: {learning_rate}")
    print(f" • Batch size: {per_device_train_batch_size}")
    print(f" • Gradient accumulation: {gradient_accumulation_steps}")
    print(f" • Effective batch size: {per_device_train_batch_size * gradient_accumulation_steps}")

    ta_kwargs = {
        # Training duration
        "num_train_epochs": num_train_epochs,
        "max_steps": max_steps,
        # Learning rate
        "learning_rate": learning_rate,
        "lr_scheduler_type": lr_scheduler_type,
        "lr_scheduler_kwargs": lr_scheduler_kwargs,
        "warmup_ratio": warmup_ratio,
        "warmup_steps": warmup_steps,
        # Batch configuration
        "per_device_train_batch_size": per_device_train_batch_size,
        "per_device_eval_batch_size": per_device_eval_batch_size,
        "gradient_accumulation_steps": gradient_accumulation_steps,
        # Model configuration
        "gradient_checkpointing": getattr(config, 'use_gradient_checkpointing', True),
        # Mixed precision (tf32 is dropped below if unsupported)
        "fp16": fp16,
        "bf16": bf16,
        "tf32": tf32,
        # Optimizer (optionally use fused AdamW if available through config)
        "optim": getattr(config, 'optimizer', 'adamw_torch'),
        # Regularization
        "weight_decay": weight_decay,
        "max_grad_norm": max_grad_norm,
        # Evaluation (renamed below for newer transformers)
        "evaluation_strategy": eval_strategy,
        "eval_steps": eval_steps,
        "eval_accumulation_steps": eval_accumulation_steps,
        # Logging
        "logging_steps": logging_steps,
        # Saving
        "save_strategy": save_strategy,
        "save_steps": save_steps,
        "save_total_limit": save_total_limit,
        # Output
        "output_dir": output_dir,
        # Data loading
        "dataloader_num_workers": dataloader_num_workers,
        "dataloader_pin_memory": getattr(config, 'dataloader_pin_memory', True),
        "dataloader_prefetch_factor": dataloader_prefetch_factor,
        # Performance
        "group_by_length": getattr(config, 'group_by_length', True),
        "remove_unused_columns": getattr(config, 'remove_unused_columns', True),
        # HuggingFace Hub
        "push_to_hub": getattr(config, 'push_to_hub', False),
        # Monitoring: explicit "none" — omitting report_to lets older transformers
        # default to "all" and initialise every installed integration.
        "report_to": "trackio" if getattr(config, 'enable_tracking', False) else "none",
    }
    # Drop any None-valued kwargs so library defaults apply.
    ta_kwargs = {k: v for k, v in ta_kwargs.items() if v is not None}

    try:
        param_names = set(inspect.signature(TrainingArguments.__init__).parameters)
    except Exception:
        param_names = set()

    # 'evaluation_strategy' was renamed to 'eval_strategy' in newer transformers.
    # pop() with a default: the key is already gone if eval_strategy was None.
    if "evaluation_strategy" not in param_names:
        eval_value = ta_kwargs.pop("evaluation_strategy", None)
        if "eval_strategy" in param_names and eval_value is not None:
            ta_kwargs["eval_strategy"] = eval_value

    # Remove any kwargs not supported by the installed transformers version.
    if param_names:
        ta_kwargs = {k: v for k, v in ta_kwargs.items() if k in param_names}

    return TrainingArguments(**ta_kwargs)
def _load_gpt_oss_config(config_path):
    """Load a GPT-OSS configuration object from a Python file.

    The file is executed as a module named ``config_module``. A module-level
    ``config`` attribute is returned if present. Otherwise the first attribute
    (in ``dir()`` order) that has a ``model_name`` naming a GPT-OSS model, or whose
    attribute name contains ``GPTOSS``, is returned.

    Raises:
        FileNotFoundError: if ``config_path`` does not exist.
        ImportError: if the file cannot be loaded as a Python module.
        ValueError: if no GPT-OSS configuration is found in the module.
    """
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Configuration file not found: {config_path}")
    import importlib.util
    spec = importlib.util.spec_from_file_location("config_module", config_path)
    # spec is None when the file suffix is not a recognised Python source suffix.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load configuration module from {config_path}")
    config_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(config_module)
    if hasattr(config_module, 'config'):
        return config_module.config
    # Try to find a config class/object
    for attr_name in dir(config_module):
        attr = getattr(config_module, attr_name)
        if not hasattr(attr, 'model_name'):
            continue
        model_name = attr.model_name
        # Hub ids are usually hyphenated ("openai/gpt-oss-20b"); normalise so both
        # the "gpt-oss" and "gpt_oss" spellings are recognised. Non-string
        # model_name values (e.g. None) simply don't match.
        normalized = model_name.lower().replace('-', '_') if isinstance(model_name, str) else ''
        if 'gpt_oss' in normalized or 'GPTOSS' in attr_name:
            return attr
    raise ValueError(f"No GPT-OSS configuration found in {config_path}")


def _apply_scheduler_env_overrides(config):
    """Apply LR-scheduler overrides exported by the CLI as environment variables.

    Reads ``GPT_OSS_SCHEDULER`` and, for ``cosine_with_min_lr`` only,
    ``GPT_OSS_MIN_LR`` (absolute floor, preferred) or ``GPT_OSS_MIN_LR_RATE``
    (relative floor). Does nothing when ``GPT_OSS_SCHEDULER`` is unset or empty.
    Best effort: failures are reported as warnings and never abort training.
    """
    env_scheduler = os.environ.get("GPT_OSS_SCHEDULER")
    if not env_scheduler:
        return
    try:
        config.scheduler = env_scheduler
        # Prepare/normalize lr scheduler kwargs container
        if getattr(config, 'lr_scheduler_kwargs', None) is None:
            config.lr_scheduler_kwargs = {}
        sched_kwargs = config.lr_scheduler_kwargs
        if env_scheduler == 'cosine_with_min_lr':
            env_min_lr = os.environ.get("GPT_OSS_MIN_LR")
            env_min_lr_rate = os.environ.get("GPT_OSS_MIN_LR_RATE")
            # Clear conflicting warmup keys to avoid signature issues
            for k in ('warmup_steps', 'num_warmup_steps', 'warmup_ratio'):
                sched_kwargs.pop(k, None)
            # This schedule accepts exactly one of min_lr / min_lr_rate, so every
            # branch below leaves only one of them in the kwargs.
            if env_min_lr is not None:
                try:
                    config.min_lr = float(env_min_lr)
                    sched_kwargs['min_lr'] = config.min_lr
                    sched_kwargs.pop('min_lr_rate', None)
                except (TypeError, ValueError):
                    print(f"Warning: ignoring invalid GPT_OSS_MIN_LR={env_min_lr!r}")
            elif env_min_lr_rate is not None:
                try:
                    sched_kwargs['min_lr_rate'] = float(env_min_lr_rate)
                    # Leave config.min_lr untouched; only drop it from the kwargs.
                    sched_kwargs.pop('min_lr', None)
                except (TypeError, ValueError):
                    print(f"Warning: ignoring invalid GPT_OSS_MIN_LR_RATE={env_min_lr_rate!r}")
            else:
                # Ensure at least one constraint exists; prefer absolute from config if valid
                try:
                    if getattr(config, 'min_lr', None) is not None:
                        sched_kwargs['min_lr'] = float(config.min_lr)
                        sched_kwargs.pop('min_lr_rate', None)
                    else:
                        sched_kwargs.setdefault('min_lr_rate', 0.1)
                except (TypeError, ValueError):
                    sched_kwargs.setdefault('min_lr_rate', 0.1)
        elif isinstance(sched_kwargs, dict):
            # Other schedulers do not accept the min-lr keys: strip them.
            for k in ('min_lr', 'min_lr_rate'):
                sched_kwargs.pop(k, None)
    except Exception as e:
        print(f"Warning: failed to apply scheduler overrides: {e}")


def train_gpt_oss(config_path, experiment_name, output_dir, trackio_url, trainer_type="sft"):
    """Main training function for GPT-OSS.

    Loads the configuration, applies runtime and CLI scheduler overrides, builds
    the model, LoRA adapter and dataset splits, wires up monitoring, and trains.
    It then saves to ``output_dir`` and optionally pushes to the Hub.

    Args:
        config_path: Path to a Python file defining the configuration.
        experiment_name: Stored on the config and passed to the monitors.
        output_dir: Passed to ``create_sft_config`` and ``trainer.save_model``.
        trackio_url: Exported as ``TRACKIO_URL`` when that variable is unset.
        trainer_type: ``"dpo"`` selects TRL's DPOTrainer; any other value uses SFTTrainer.

    Returns:
        The trainer instance after training and saving.

    Raises:
        FileNotFoundError, ImportError, ValueError: from configuration loading.
        RuntimeError: if ``trainer_type == "dpo"`` and TRL has no DPOTrainer.
    """
    print("=== GPT-OSS Training Pipeline ===")
    print(f"Config: {config_path}")
    print(f"Experiment: {experiment_name}")
    print(f"Output: {output_dir}")
    print(f"Trackio: {trackio_url}")
    print(f"Trainer: {trainer_type}")

    config = _load_gpt_oss_config(config_path)
    # Update config with runtime parameters
    config.experiment_name = experiment_name
    config.trackio_url = trackio_url
    config.trainer_type = trainer_type
    _apply_scheduler_env_overrides(config)

    model, tokenizer = load_gpt_oss_model_and_tokenizer(config)
    peft_model = setup_lora_for_gpt_oss(model, config)
    dataset = load_dataset_from_config(config)
    train_dataset, eval_dataset, _test_dataset = split_dataset(dataset, config)

    # Ensure TRACKIO_URL env is set so SmolLM3Monitor picks it up
    if trackio_url and not os.environ.get('TRACKIO_URL'):
        os.environ['TRACKIO_URL'] = trackio_url
        os.environ.setdefault('TRACKIO_SPACE_ID', trackio_url)
    # Setup Trackio tracking (Space API client). The reference is kept for the
    # duration of this call even though it is not otherwise used here.
    trackio_client = setup_trackio_tracking(config)

    # Unified monitor to ensure metrics get logged to dataset/Space
    monitor = None
    try:
        from monitoring import SmolLM3Monitor
        monitor = SmolLM3Monitor(
            experiment_name=experiment_name,
            trackio_url=trackio_url,
            trackio_token=getattr(config, 'trackio_token', None) or os.environ.get('HF_TOKEN'),
            enable_tracking=True,
            log_artifacts=True,
            log_metrics=True,
            log_config=True,
            hf_token=os.environ.get('HF_TOKEN'),
            dataset_repo=os.environ.get('TRACKIO_DATASET_REPO'),
            monitoring_mode=os.environ.get('MONITORING_MODE', 'both'),
        )
        try:
            cfg_dict = {k: getattr(config, k) for k in dir(config) if not k.startswith('_') and not callable(getattr(config, k))}
            monitor.log_configuration(cfg_dict)
        except Exception as e:
            print(f"Warning: failed to log configuration: {e}")
    except Exception as e:
        print(f"Warning: failed to initialize monitor: {e}")

    # Project monitor (HF Datasets + Trackio Space if configured)
    project_monitor = None
    monitor_callback = None
    if create_monitor_from_config is not None:
        try:
            project_monitor = create_monitor_from_config(config, experiment_name=experiment_name)
            try:
                cfg_dict = {k: v for k, v in config.__dict__.items() if not k.startswith('_')}
                project_monitor.log_config(cfg_dict)
            except Exception as e:
                print(f"Warning: failed to log configuration to project monitor: {e}")
            monitor_callback = project_monitor.create_monitoring_callback()
            # If the explicit monitor failed, use this one for summary/close
            if monitor is None:
                monitor = project_monitor
        except Exception as e:
            print(f"Warning: failed to initialize project monitor: {e}")

    # Resolve training callbacks once. Prefer the explicit monitor's callback,
    # falling back to the project monitor's. Reuse the already-created callback
    # when both names refer to the same monitor instead of creating a duplicate.
    callbacks = []
    if monitor is not None and monitor is not project_monitor:
        try:
            callbacks = [monitor.create_monitoring_callback()]
        except Exception as e:
            print(f"Warning: failed to create monitoring callback: {e}")
    if not callbacks and monitor_callback is not None:
        callbacks = [monitor_callback]

    sft_config = create_sft_config(config, output_dir)

    # Create trainer with version-robust kwargs (only pass what __init__ accepts)
    if trainer_type == 'dpo':
        if DPOTrainer is None:
            raise RuntimeError("DPOTrainer is not available in this TRL version. Please upgrade 'trl'.")
        print("Creating DPO trainer...")
        try:
            dpo_params = set(inspect.signature(DPOTrainer.__init__).parameters.keys())
        except Exception:
            dpo_params = {"model", "args", "train_dataset", "tokenizer", "beta", "prompt_column", "chosen_column", "rejected_column"}
        dpo_kwargs = {
            "model": peft_model,
            "args": sft_config,
            "train_dataset": train_dataset,
        }
        # Newer TRL versions moved `beta` to DPOConfig. Passing it unconditionally
        # raises TypeError there, so it is signature-filtered like the other kwargs.
        if "beta" in dpo_params:
            dpo_kwargs["beta"] = getattr(config, 'dpo_beta', 0.1)
        if "tokenizer" in dpo_params:
            dpo_kwargs["tokenizer"] = tokenizer
        elif "processing_class" in dpo_params:
            dpo_kwargs["processing_class"] = tokenizer
        if "prompt_column" in dpo_params:
            dpo_kwargs["prompt_column"] = "prompt"
        if "chosen_column" in dpo_params:
            dpo_kwargs["chosen_column"] = "chosen"
        if "rejected_column" in dpo_params:
            dpo_kwargs["rejected_column"] = "rejected"
        if "callbacks" in dpo_params:
            dpo_kwargs["callbacks"] = callbacks
        dpo_kwargs = {k: v for k, v in dpo_kwargs.items() if v is not None}
        if "eval_dataset" in dpo_params and eval_dataset is not None:
            dpo_kwargs["eval_dataset"] = eval_dataset
        trainer = DPOTrainer(**dpo_kwargs)
    else:
        print("Creating SFT trainer...")
        try:
            sft_params = set(inspect.signature(SFTTrainer.__init__).parameters.keys())
        except Exception:
            sft_params = {"model", "args", "train_dataset", "tokenizer", "dataset_text_field", "max_seq_length"}
        sft_kwargs = {
            "model": peft_model,
            "args": sft_config,
            "train_dataset": train_dataset,
        }
        # Prefer passing tokenizer if supported; otherwise try processing_class
        if "tokenizer" in sft_params:
            sft_kwargs["tokenizer"] = tokenizer
        elif "processing_class" in sft_params:
            sft_kwargs["processing_class"] = tokenizer
        # Pass dataset text field if supported (we produced a 'text' column)
        if "dataset_text_field" in sft_params:
            sft_kwargs["dataset_text_field"] = "text"
        if "max_seq_length" in sft_params:
            sft_kwargs["max_seq_length"] = getattr(config, 'max_seq_length', 2048)
        # Enable sequence packing if supported by TRL
        if "packing" in sft_params:
            sft_kwargs["packing"] = getattr(config, 'packing', False)
        if "callbacks" in sft_params:
            sft_kwargs["callbacks"] = callbacks
        sft_kwargs = {k: v for k, v in sft_kwargs.items() if v is not None}
        if "eval_dataset" in sft_params and eval_dataset is not None:
            sft_kwargs["eval_dataset"] = eval_dataset
        trainer = SFTTrainer(**sft_kwargs)

    print("Starting GPT-OSS training...")
    try:
        trainer.train()
    finally:
        # Ensure periodic metrics are flushed at the end even if interrupted
        try:
            if monitor is not None:
                monitor._save_to_hf_dataset({'status': 'running'})
        except Exception:
            pass

    print("Saving trained model...")
    trainer.save_model(output_dir)
    if sft_config.push_to_hub:
        print("Pushing model to Hugging Face Hub...")
        # NOTE(review): assumes config.dataset_name, when present, is the training
        # dataset id; falls back to the previously hard-coded dataset otherwise.
        dataset_name = getattr(config, 'dataset_name', None) or "HuggingFaceH4/Multilingual-Thinking"
        trainer.push_to_hub(dataset_name=dataset_name)

    # Log training summary, then close the monitor exactly once
    if monitor is not None:
        try:
            monitor.log_training_summary({
                'output_dir': output_dir,
                'model_name': getattr(config, 'model_name', 'unknown'),
            })
        except Exception as e:
            print(f"Warning: failed to log training summary: {e}")
        try:
            monitor.close()
        except Exception as e:
            print(f"Warning: failed to close monitor: {e}")

    print("GPT-OSS training completed successfully!")
    return trainer
def main():
    """Command-line entry point for GPT-OSS training.

    Parses arguments and exports the optional LR-scheduler overrides as the
    ``GPT_OSS_SCHEDULER``, ``GPT_OSS_MIN_LR`` and ``GPT_OSS_MIN_LR_RATE``
    environment variables, which ``train_gpt_oss`` reads. Then runs training.

    Exit codes: 2 for invalid arguments (argparse), 1 if the configuration file
    is missing or training raises.
    """
    import traceback

    parser = argparse.ArgumentParser(description="GPT-OSS Training Script")
    parser.add_argument("--config", required=True, help="Path to configuration file")
    parser.add_argument("--experiment-name", required=True, help="Experiment name")
    parser.add_argument("--output-dir", required=True, help="Output directory for checkpoints")
    parser.add_argument("--trackio-url", help="Trackio URL for monitoring")
    parser.add_argument("--trainer-type", default="sft", choices=["sft", "dpo"], help="Trainer type")
    # Optional LR scheduler overrides (applied across any GPT-OSS config)
    parser.add_argument(
        "--scheduler",
        choices=["linear", "cosine", "cosine_with_min_lr", "constant"],
        help="Override LR scheduler for this run",
    )
    parser.add_argument(
        "--min-lr",
        type=float,
        dest="min_lr",
        help="Absolute floor for LR (used when scheduler is 'cosine_with_min_lr')",
    )
    parser.add_argument(
        "--min-lr-rate",
        type=float,
        dest="min_lr_rate",
        help="Relative LR floor rate in (0,1) for TRL scheduler (used when scheduler is 'cosine_with_min_lr')",
    )
    args = parser.parse_args()

    # Enforce the ranges documented in the help text. The `not` forms also reject NaN.
    if args.min_lr is not None and not args.min_lr >= 0:
        parser.error("--min-lr must be a non-negative number")
    if args.min_lr_rate is not None and not 0 < args.min_lr_rate < 1:
        parser.error("--min-lr-rate must be in the open interval (0, 1)")
    # The LR floors are only read for the cosine_with_min_lr scheduler; tell the
    # user instead of silently dropping them.
    effective_scheduler = args.scheduler or os.environ.get("GPT_OSS_SCHEDULER")
    if (args.min_lr is not None or args.min_lr_rate is not None) and effective_scheduler != "cosine_with_min_lr":
        print("Warning: --min-lr/--min-lr-rate only take effect with --scheduler cosine_with_min_lr; they will be ignored")

    if not os.path.exists(args.config):
        print(f"Error: Configuration file not found: {args.config}")
        sys.exit(1)
    os.makedirs(args.output_dir, exist_ok=True)

    try:
        # Expose scheduler overrides via environment so they are picked up
        # consistently across helper functions.
        if args.scheduler:
            os.environ["GPT_OSS_SCHEDULER"] = args.scheduler
        if args.min_lr is not None:
            os.environ["GPT_OSS_MIN_LR"] = str(args.min_lr)
        if args.min_lr_rate is not None:
            os.environ["GPT_OSS_MIN_LR_RATE"] = str(args.min_lr_rate)
        train_gpt_oss(
            config_path=args.config,
            experiment_name=args.experiment_name,
            output_dir=args.output_dir,
            trackio_url=args.trackio_url,
            trainer_type=args.trainer_type,
        )
    except Exception as e:
        print(f"Error during training: {e}")
        # Keep the full traceback; the one-line message alone hides where training failed.
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()