SmolFactory / src /trainer.py
Tonic's picture
adds experiment id fix
c560f4f
"""
SmolLM3 Trainer
Handles the training loop and integrates with Hugging Face Trainer
"""
import os
import torch
import logging
from typing import Optional, Dict, Any
from transformers import Trainer, TrainingArguments
from trl import SFTTrainer
import json
# Import monitoring
from monitoring import create_monitor_from_config
logger = logging.getLogger(__name__)
class SmolLM3Trainer:
"""Trainer for SmolLM3 fine-tuning"""
def __init__(
self,
model,
dataset,
config,
output_dir: str,
init_from: str = "scratch",
use_sft_trainer: bool = True
):
self.model = model
self.dataset = dataset
self.config = config
self.output_dir = output_dir
self.init_from = init_from
self.use_sft_trainer = use_sft_trainer
# Initialize monitoring
self.monitor = create_monitor_from_config(config)
# Setup trainer
self.trainer = self._setup_trainer()
def _setup_trainer(self):
"""Setup the trainer"""
logger.info("Setting up trainer")
# Get training arguments
training_args = self.model.get_training_arguments(
output_dir=self.output_dir,
save_steps=self.config.save_steps,
eval_steps=self.config.eval_steps,
logging_steps=self.config.logging_steps,
max_steps=self.config.max_iters,
)
# Debug: Print training arguments
logger.info("Training arguments keys: %s", list(training_args.__dict__.keys()))
logger.info("Training arguments type: %s", type(training_args))
# Get datasets
logger.info("Getting train dataset...")
train_dataset = self.dataset.get_train_dataset()
logger.info("Train dataset: %s with %d samples", type(train_dataset), len(train_dataset))
logger.info("Getting eval dataset...")
eval_dataset = self.dataset.get_eval_dataset()
logger.info("Eval dataset: %s with %d samples", type(eval_dataset), len(eval_dataset))
# Get data collator
logger.info("Getting data collator...")
data_collator = self.dataset.get_data_collator()
logger.info("Data collator: %s", type(data_collator))
# Add monitoring callbacks
callbacks = []
# Add simple console callback for basic monitoring
from transformers import TrainerCallback
outer = self
class SimpleConsoleCallback(TrainerCallback):
def on_init_end(self, args, state, control, **kwargs):
"""Called when training initialization is complete"""
print("🔧 Training initialization completed")
def on_log(self, args, state, control, logs=None, **kwargs):
"""Log metrics to console"""
if logs and isinstance(logs, dict):
step = state.global_step if hasattr(state, 'global_step') else 'unknown'
loss = logs.get('loss', 'N/A')
lr = logs.get('learning_rate', 'N/A')
# Fix format string error by ensuring proper type conversion
if isinstance(loss, (int, float)):
loss_str = f"{loss:.4f}"
else:
loss_str = str(loss)
if isinstance(lr, (int, float)):
lr_str = f"{lr:.2e}"
else:
lr_str = str(lr)
print(f"Step {step}: loss={loss_str}, lr={lr_str}")
# Persist metrics via our monitor when Trackio callback isn't active
try:
if outer.monitor:
# Avoid double logging when Trackio callback is used
if not outer.monitor.enable_tracking:
outer.monitor.log_metrics(dict(logs), step if isinstance(step, int) else None)
outer.monitor.log_system_metrics(step if isinstance(step, int) else None)
except Exception as e:
logger.warning("SimpleConsoleCallback metrics persistence failed: %s", e)
def on_train_begin(self, args, state, control, **kwargs):
print("🚀 Training started!")
def on_train_end(self, args, state, control, **kwargs):
print("✅ Training completed!")
def on_save(self, args, state, control, **kwargs):
step = state.global_step if hasattr(state, 'global_step') else 'unknown'
print(f"💾 Checkpoint saved at step {step}")
try:
if outer.monitor and not outer.monitor.enable_tracking:
checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{step}")
if os.path.exists(checkpoint_path):
outer.monitor.log_model_checkpoint(checkpoint_path, step if isinstance(step, int) else None)
except Exception as e:
logger.warning("SimpleConsoleCallback checkpoint persistence failed: %s", e)
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
if metrics and isinstance(metrics, dict):
step = state.global_step if hasattr(state, 'global_step') else 'unknown'
eval_loss = metrics.get('eval_loss', 'N/A')
print(f"📊 Evaluation at step {step}: eval_loss={eval_loss}")
try:
if outer.monitor and not outer.monitor.enable_tracking:
outer.monitor.log_evaluation_results(dict(metrics), step if isinstance(step, int) else None)
except Exception as e:
logger.warning("SimpleConsoleCallback eval persistence failed: %s", e)
# Add console callback
callbacks.append(SimpleConsoleCallback())
logger.info("Added simple console monitoring callback")
# Add monitoring callback if available (always attach; it persists to dataset even if Trackio is disabled)
if self.monitor:
try:
trackio_callback = self.monitor.create_monitoring_callback()
if trackio_callback:
callbacks.append(trackio_callback)
logger.info("Added monitoring callback")
else:
logger.warning("Failed to create monitoring callback")
except Exception as e:
logger.error("Error creating monitoring callback: %s", e)
logger.info("Continuing with console monitoring only")
logger.info("Total callbacks: %d", len(callbacks))
# Initialize trackio for TRL compatibility without creating a second experiment
try:
import trackio
if self.monitor:
# Share the same monitor/experiment with the trackio shim
try:
trackio.set_monitor(self.monitor) # type: ignore[attr-defined]
except Exception:
# Fallback: ensure the shim at least knows the current ID
pass
logger.info(
"Using shared Trackio monitor with experiment ID: %s",
getattr(self.monitor, 'experiment_id', None)
)
else:
# Last resort: initialize via shim
_ = trackio.init(
project_name=getattr(self.config, 'experiment_name', 'smollm3_experiment'),
experiment_name=getattr(self.config, 'experiment_name', 'smollm3_experiment'),
trackio_url=getattr(self.config, 'trackio_url', None),
trackio_token=getattr(self.config, 'trackio_token', None),
hf_token=getattr(self.config, 'hf_token', None),
dataset_repo=getattr(self.config, 'dataset_repo', None)
)
except Exception as e:
logger.warning(f"Failed to wire trackio shim: {e}")
logger.info("Continuing without trackio shim integration")
# Try SFTTrainer first (better for instruction tuning)
logger.info("Creating SFTTrainer with training arguments...")
logger.info("Training args type: %s", type(training_args))
try:
trainer = SFTTrainer(
model=self.model.model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
args=training_args,
data_collator=data_collator,
callbacks=callbacks,
)
logger.info("Using SFTTrainer (optimized for instruction tuning)")
except Exception as e:
logger.warning("SFTTrainer failed: %s", e)
logger.error("SFTTrainer creation error details: %s: %s", type(e).__name__, str(e))
# Fallback to standard Trainer
try:
trainer = Trainer(
model=self.model.model,
tokenizer=self.model.tokenizer,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=data_collator,
callbacks=callbacks,
)
logger.info("Using standard Hugging Face Trainer (fallback)")
except Exception as e2:
logger.error("Standard Trainer also failed: %s", e2)
raise e2
return trainer
def load_checkpoint(self, checkpoint_path: str):
"""Load checkpoint for resuming training"""
logger.info("Loading checkpoint from %s", checkpoint_path)
if self.init_from == "resume":
# Load the model from checkpoint
self.model.load_checkpoint(checkpoint_path)
# Update trainer with loaded model
self.trainer.model = self.model.model
logger.info("Checkpoint loaded successfully")
elif self.init_from == "pretrained":
# Model is already loaded from pretrained
logger.info("Using pretrained model")
else:
logger.info("Starting from scratch")
def train(self):
"""Start training"""
logger.info("Starting training")
# Log configuration (always persist to dataset; Trackio if enabled)
if self.monitor:
try:
config_dict = {k: v for k, v in self.config.__dict__.items() if not k.startswith('_')}
self.monitor.log_config(config_dict)
except Exception as e:
logger.warning("Failed to log configuration: %s", e)
# Log experiment URL only if available
try:
experiment_url = self.monitor.get_experiment_url()
if experiment_url:
logger.info("Trackio experiment URL: %s", experiment_url)
except Exception:
pass
# Load checkpoint if resuming
if self.init_from == "resume":
checkpoint_path = "/input-checkpoint"
if os.path.exists(checkpoint_path):
self.load_checkpoint(checkpoint_path)
else:
logger.warning("Checkpoint path %s not found, starting from scratch", checkpoint_path)
# Start training
try:
logger.info("About to start trainer.train()")
train_result = self.trainer.train()
# Save the final model
self.trainer.save_model()
# Save training results
with open(os.path.join(self.output_dir, "train_results.json"), "w") as f:
json.dump(train_result.metrics, f, indent=2)
# Log training summary (always persist to dataset; Trackio if enabled)
if self.monitor:
try:
summary = {
'final_loss': train_result.metrics.get('train_loss', 0),
'total_steps': train_result.metrics.get('train_runtime', 0),
'training_time': train_result.metrics.get('train_runtime', 0),
'output_dir': self.output_dir,
'model_name': getattr(self.config, 'model_name', 'unknown'),
}
self.monitor.log_training_summary(summary)
self.monitor.close()
except Exception as e:
logger.warning("Failed to log training summary: %s", e)
# Finish trackio experiment
try:
import trackio
trackio.finish()
logger.info("Trackio experiment finished")
except Exception as e:
logger.warning(f"Failed to finish trackio experiment: {e}")
logger.info("Training completed successfully!")
logger.info("Training metrics: %s", train_result.metrics)
except Exception as e:
logger.error("Training failed: %s", e)
# Close monitoring on error (still persist final status to dataset)
if self.monitor:
try:
self.monitor.close(final_status="failed")
except Exception:
pass
# Finish trackio experiment on error
try:
import trackio
trackio.finish()
except Exception as finish_error:
logger.warning(f"Failed to finish trackio experiment on error: {finish_error}")
raise
def evaluate(self):
"""Evaluate the model"""
logger.info("Starting evaluation")
try:
eval_results = self.trainer.evaluate()
# Save evaluation results
with open(os.path.join(self.output_dir, "eval_results.json"), "w") as f:
json.dump(eval_results, f, indent=2)
logger.info("Evaluation completed: %s", eval_results)
return eval_results
except Exception as e:
logger.error("Evaluation failed: %s", e)
raise
def save_model(self, path: Optional[str] = None):
"""Save the trained model"""
save_path = path or self.output_dir
logger.info("Saving model to %s", save_path)
try:
self.trainer.save_model(save_path)
self.model.tokenizer.save_pretrained(save_path)
# Save training configuration
if self.config:
config_dict = {k: v for k, v in self.config.__dict__.items()
if not k.startswith('_')}
with open(os.path.join(save_path, 'training_config.json'), 'w') as f:
json.dump(config_dict, f, indent=2, default=str)
logger.info("Model saved successfully!")
except Exception as e:
logger.error("Failed to save model: %s", e)
raise
class SmolLM3DPOTrainer:
"""DPO Trainer for SmolLM3 preference optimization"""
def __init__(
self,
model,
dataset,
config,
output_dir: str,
ref_model=None
):
self.model = model
self.dataset = dataset
self.config = config
self.output_dir = output_dir
self.ref_model = ref_model
# Setup DPO trainer
self.trainer = self._setup_dpo_trainer()
def _setup_dpo_trainer(self):
"""Setup DPO trainer"""
from trl import DPOTrainer
# Get training arguments
training_args = self.model.get_training_arguments(
output_dir=self.output_dir,
save_steps=self.config.save_steps,
eval_steps=self.config.eval_steps,
logging_steps=self.config.logging_steps,
max_steps=self.config.max_iters,
)
# Get preference dataset
train_dataset = self.dataset.get_train_dataset()
eval_dataset = self.dataset.get_eval_dataset()
# Setup DPO trainer
trainer = DPOTrainer(
model=self.model.model,
ref_model=self.ref_model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=self.model.tokenizer,
max_prompt_length=self.config.max_seq_length // 2,
max_length=self.config.max_seq_length,
)
return trainer
def train(self):
"""Start DPO training"""
logger.info("Starting DPO training")
try:
train_result = self.trainer.train()
# Save the final model
self.trainer.save_model()
# Save training results
with open(os.path.join(self.output_dir, "dpo_train_results.json"), "w") as f:
json.dump(train_result.metrics, f, indent=2)
logger.info("DPO training completed successfully!")
logger.info("Training metrics: %s", train_result.metrics)
except Exception as e:
logger.error("DPO training failed: %s", e)
raise