""" SmolLM3 Trainer Handles the training loop and integrates with Hugging Face Trainer """ import os import torch import logging from typing import Optional, Dict, Any from transformers import Trainer, TrainingArguments from trl import SFTTrainer import json # Import monitoring from monitoring import create_monitor_from_config logger = logging.getLogger(__name__) class SmolLM3Trainer: """Trainer for SmolLM3 fine-tuning""" def __init__( self, model, dataset, config, output_dir: str, init_from: str = "scratch", use_sft_trainer: bool = True ): self.model = model self.dataset = dataset self.config = config self.output_dir = output_dir self.init_from = init_from self.use_sft_trainer = use_sft_trainer # Initialize monitoring self.monitor = create_monitor_from_config(config) # Setup trainer self.trainer = self._setup_trainer() def _setup_trainer(self): """Setup the trainer""" logger.info("Setting up trainer") # Get training arguments training_args = self.model.get_training_arguments( output_dir=self.output_dir, save_steps=self.config.save_steps, eval_steps=self.config.eval_steps, logging_steps=self.config.logging_steps, max_steps=self.config.max_iters, ) # Debug: Print training arguments logger.info("Training arguments keys: %s", list(training_args.__dict__.keys())) logger.info("Training arguments type: %s", type(training_args)) # Get datasets logger.info("Getting train dataset...") train_dataset = self.dataset.get_train_dataset() logger.info("Train dataset: %s with %d samples", type(train_dataset), len(train_dataset)) logger.info("Getting eval dataset...") eval_dataset = self.dataset.get_eval_dataset() logger.info("Eval dataset: %s with %d samples", type(eval_dataset), len(eval_dataset)) # Get data collator logger.info("Getting data collator...") data_collator = self.dataset.get_data_collator() logger.info("Data collator: %s", type(data_collator)) # Add monitoring callbacks callbacks = [] # Add simple console callback for basic monitoring from transformers import TrainerCallback outer = self class SimpleConsoleCallback(TrainerCallback): def on_init_end(self, args, state, control, **kwargs): """Called when training initialization is complete""" print("🔧 Training initialization completed") def on_log(self, args, state, control, logs=None, **kwargs): """Log metrics to console""" if logs and isinstance(logs, dict): step = state.global_step if hasattr(state, 'global_step') else 'unknown' loss = logs.get('loss', 'N/A') lr = logs.get('learning_rate', 'N/A') # Fix format string error by ensuring proper type conversion if isinstance(loss, (int, float)): loss_str = f"{loss:.4f}" else: loss_str = str(loss) if isinstance(lr, (int, float)): lr_str = f"{lr:.2e}" else: lr_str = str(lr) print(f"Step {step}: loss={loss_str}, lr={lr_str}") # Persist metrics via our monitor when Trackio callback isn't active try: if outer.monitor: # Avoid double logging when Trackio callback is used if not outer.monitor.enable_tracking: outer.monitor.log_metrics(dict(logs), step if isinstance(step, int) else None) outer.monitor.log_system_metrics(step if isinstance(step, int) else None) except Exception as e: logger.warning("SimpleConsoleCallback metrics persistence failed: %s", e) def on_train_begin(self, args, state, control, **kwargs): print("🚀 Training started!") def on_train_end(self, args, state, control, **kwargs): print("✅ Training completed!") def on_save(self, args, state, control, **kwargs): step = state.global_step if hasattr(state, 'global_step') else 'unknown' print(f"💾 Checkpoint saved at step {step}") try: if 

class SmolLM3Trainer:
    """Trainer for SmolLM3 fine-tuning"""

    def __init__(
        self,
        model,
        dataset,
        config,
        output_dir: str,
        init_from: str = "scratch",
        use_sft_trainer: bool = True
    ):
        self.model = model
        self.dataset = dataset
        self.config = config
        self.output_dir = output_dir
        self.init_from = init_from
        self.use_sft_trainer = use_sft_trainer

        # Initialize monitoring
        self.monitor = create_monitor_from_config(config)

        # Setup trainer
        self.trainer = self._setup_trainer()

    def _setup_trainer(self):
        """Setup the trainer"""
        logger.info("Setting up trainer")

        # Get training arguments
        training_args = self.model.get_training_arguments(
            output_dir=self.output_dir,
            save_steps=self.config.save_steps,
            eval_steps=self.config.eval_steps,
            logging_steps=self.config.logging_steps,
            max_steps=self.config.max_iters,
        )

        # Debug: print training arguments
        logger.info("Training arguments keys: %s", list(training_args.__dict__.keys()))
        logger.info("Training arguments type: %s", type(training_args))

        # Get datasets
        logger.info("Getting train dataset...")
        train_dataset = self.dataset.get_train_dataset()
        logger.info("Train dataset: %s with %d samples", type(train_dataset), len(train_dataset))

        logger.info("Getting eval dataset...")
        eval_dataset = self.dataset.get_eval_dataset()
        logger.info("Eval dataset: %s with %d samples", type(eval_dataset), len(eval_dataset))

        # Get data collator
        logger.info("Getting data collator...")
        data_collator = self.dataset.get_data_collator()
        logger.info("Data collator: %s", type(data_collator))

        # Add monitoring callbacks
        callbacks = []

        # Add simple console callback for basic monitoring
        from transformers import TrainerCallback
        outer = self

        class SimpleConsoleCallback(TrainerCallback):
            def on_init_end(self, args, state, control, **kwargs):
                """Called when training initialization is complete"""
                print("🔧 Training initialization completed")

            def on_log(self, args, state, control, logs=None, **kwargs):
                """Log metrics to console"""
                if logs and isinstance(logs, dict):
                    step = state.global_step if hasattr(state, 'global_step') else 'unknown'
                    loss = logs.get('loss', 'N/A')
                    lr = logs.get('learning_rate', 'N/A')
                    # Coerce before formatting: logs may contain non-numeric
                    # values (e.g. the 'N/A' fallback above)
                    if isinstance(loss, (int, float)):
                        loss_str = f"{loss:.4f}"
                    else:
                        loss_str = str(loss)
                    if isinstance(lr, (int, float)):
                        lr_str = f"{lr:.2e}"
                    else:
                        lr_str = str(lr)
                    print(f"Step {step}: loss={loss_str}, lr={lr_str}")
                    # Persist metrics via our monitor when Trackio callback isn't active
                    try:
                        if outer.monitor:
                            # Avoid double logging when Trackio callback is used
                            if not outer.monitor.enable_tracking:
                                outer.monitor.log_metrics(dict(logs), step if isinstance(step, int) else None)
                                outer.monitor.log_system_metrics(step if isinstance(step, int) else None)
                    except Exception as e:
                        logger.warning("SimpleConsoleCallback metrics persistence failed: %s", e)

            def on_train_begin(self, args, state, control, **kwargs):
                print("🚀 Training started!")

            def on_train_end(self, args, state, control, **kwargs):
                print("✅ Training completed!")

            def on_save(self, args, state, control, **kwargs):
                step = state.global_step if hasattr(state, 'global_step') else 'unknown'
                print(f"💾 Checkpoint saved at step {step}")
                try:
                    if outer.monitor and not outer.monitor.enable_tracking:
                        checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{step}")
                        if os.path.exists(checkpoint_path):
                            outer.monitor.log_model_checkpoint(checkpoint_path, step if isinstance(step, int) else None)
                except Exception as e:
                    logger.warning("SimpleConsoleCallback checkpoint persistence failed: %s", e)

            def on_evaluate(self, args, state, control, metrics=None, **kwargs):
                if metrics and isinstance(metrics, dict):
                    step = state.global_step if hasattr(state, 'global_step') else 'unknown'
                    eval_loss = metrics.get('eval_loss', 'N/A')
                    print(f"📊 Evaluation at step {step}: eval_loss={eval_loss}")
                    try:
                        if outer.monitor and not outer.monitor.enable_tracking:
                            outer.monitor.log_evaluation_results(dict(metrics), step if isinstance(step, int) else None)
                    except Exception as e:
                        logger.warning("SimpleConsoleCallback eval persistence failed: %s", e)

        # Add console callback
        callbacks.append(SimpleConsoleCallback())
        logger.info("Added simple console monitoring callback")

        # Add monitoring callback if available (always attach; it persists to dataset even if Trackio is disabled)
        if self.monitor:
            try:
                trackio_callback = self.monitor.create_monitoring_callback()
                if trackio_callback:
                    callbacks.append(trackio_callback)
                    logger.info("Added monitoring callback")
                else:
                    logger.warning("Failed to create monitoring callback")
            except Exception as e:
                logger.error("Error creating monitoring callback: %s", e)
                logger.info("Continuing with console monitoring only")

        logger.info("Total callbacks: %d", len(callbacks))

        # Initialize trackio for TRL compatibility without creating a second experiment
        try:
            import trackio
            if self.monitor:
                # Share the same monitor/experiment with the trackio shim
                try:
                    trackio.set_monitor(self.monitor)  # type: ignore[attr-defined]
                except Exception:
                    # Fallback: ensure the shim at least knows the current ID
                    pass
                logger.info(
                    "Using shared Trackio monitor with experiment ID: %s",
                    getattr(self.monitor, 'experiment_id', None)
                )
            else:
                # Last resort: initialize via shim
                _ = trackio.init(
                    project_name=getattr(self.config, 'experiment_name', 'smollm3_experiment'),
                    experiment_name=getattr(self.config, 'experiment_name', 'smollm3_experiment'),
                    trackio_url=getattr(self.config, 'trackio_url', None),
                    trackio_token=getattr(self.config, 'trackio_token', None),
                    hf_token=getattr(self.config, 'hf_token', None),
                    dataset_repo=getattr(self.config, 'dataset_repo', None)
                )
        except Exception as e:
            logger.warning("Failed to wire trackio shim: %s", e)
            logger.info("Continuing without trackio shim integration")

        # Try SFTTrainer first (better for instruction tuning)
        logger.info("Creating SFTTrainer with training arguments...")
        logger.info("Training args type: %s", type(training_args))

        try:
            trainer = SFTTrainer(
                model=self.model.model,
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                args=training_args,
                data_collator=data_collator,
                callbacks=callbacks,
            )
            logger.info("Using SFTTrainer (optimized for instruction tuning)")
        except Exception as e:
            logger.warning("SFTTrainer failed: %s", e)
            logger.error("SFTTrainer creation error details: %s: %s", type(e).__name__, str(e))
            # Fallback to standard Trainer
            try:
                trainer = Trainer(
                    model=self.model.model,
                    tokenizer=self.model.tokenizer,
                    args=training_args,
                    train_dataset=train_dataset,
                    eval_dataset=eval_dataset,
                    data_collator=data_collator,
                    callbacks=callbacks,
                )
                logger.info("Using standard Hugging Face Trainer (fallback)")
            except Exception as e2:
                logger.error("Standard Trainer also failed: %s", e2)
                raise e2

        return trainer
    def load_checkpoint(self, checkpoint_path: str):
        """Load checkpoint for resuming training"""
        logger.info("Loading checkpoint from %s", checkpoint_path)

        if self.init_from == "resume":
            # Load the model from checkpoint
            self.model.load_checkpoint(checkpoint_path)
            # Update trainer with loaded model
            self.trainer.model = self.model.model
            logger.info("Checkpoint loaded successfully")
        elif self.init_from == "pretrained":
            # Model is already loaded from pretrained
            logger.info("Using pretrained model")
        else:
            logger.info("Starting from scratch")

    def train(self):
        """Start training"""
        logger.info("Starting training")

        # Log configuration (always persist to dataset; Trackio if enabled)
        if self.monitor:
            try:
                config_dict = {k: v for k, v in self.config.__dict__.items() if not k.startswith('_')}
                self.monitor.log_config(config_dict)
            except Exception as e:
                logger.warning("Failed to log configuration: %s", e)
            # Log experiment URL only if available
            try:
                experiment_url = self.monitor.get_experiment_url()
                if experiment_url:
                    logger.info("Trackio experiment URL: %s", experiment_url)
            except Exception:
                pass

        # Load checkpoint if resuming
        if self.init_from == "resume":
            checkpoint_path = "/input-checkpoint"
            if os.path.exists(checkpoint_path):
                self.load_checkpoint(checkpoint_path)
            else:
                logger.warning("Checkpoint path %s not found, starting from scratch", checkpoint_path)

        # Start training
        try:
            logger.info("About to start trainer.train()")
            train_result = self.trainer.train()

            # Save the final model
            self.trainer.save_model()

            # Save training results
            with open(os.path.join(self.output_dir, "train_results.json"), "w") as f:
                json.dump(train_result.metrics, f, indent=2)

            # Log training summary (always persist to dataset; Trackio if enabled)
            if self.monitor:
                try:
                    summary = {
                        'final_loss': train_result.metrics.get('train_loss', 0),
                        'total_steps': train_result.global_step,
                        'training_time': train_result.metrics.get('train_runtime', 0),
                        'output_dir': self.output_dir,
                        'model_name': getattr(self.config, 'model_name', 'unknown'),
                    }
                    self.monitor.log_training_summary(summary)
                    self.monitor.close()
                except Exception as e:
                    logger.warning("Failed to log training summary: %s", e)

            # Finish trackio experiment
            try:
                import trackio
                trackio.finish()
                logger.info("Trackio experiment finished")
            except Exception as e:
                logger.warning("Failed to finish trackio experiment: %s", e)

            logger.info("Training completed successfully!")
            logger.info("Training metrics: %s", train_result.metrics)

        except Exception as e:
            logger.error("Training failed: %s", e)
            # Close monitoring on error (still persist final status to dataset)
            if self.monitor:
                try:
                    self.monitor.close(final_status="failed")
                except Exception:
                    pass
            # Finish trackio experiment on error
            try:
                import trackio
                trackio.finish()
            except Exception as finish_error:
                logger.warning("Failed to finish trackio experiment on error: %s", finish_error)
            raise

    def evaluate(self):
        """Evaluate the model"""
        logger.info("Starting evaluation")

        try:
            eval_results = self.trainer.evaluate()

            # Save evaluation results
            with open(os.path.join(self.output_dir, "eval_results.json"), "w") as f:
                json.dump(eval_results, f, indent=2)

            logger.info("Evaluation completed: %s", eval_results)
            return eval_results

        except Exception as e:
            logger.error("Evaluation failed: %s", e)
            raise

    def save_model(self, path: Optional[str] = None):
        """Save the trained model"""
        save_path = path or self.output_dir
        logger.info("Saving model to %s", save_path)

        try:
            self.trainer.save_model(save_path)
            self.model.tokenizer.save_pretrained(save_path)

            # Save training configuration
            if self.config:
                config_dict = {k: v for k, v in self.config.__dict__.items() if not k.startswith('_')}
                with open(os.path.join(save_path, 'training_config.json'), 'w') as f:
                    json.dump(config_dict, f, indent=2, default=str)

            logger.info("Model saved successfully!")

        except Exception as e:
            logger.error("Failed to save model: %s", e)
            raise
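
# Example usage (a minimal sketch; `get_config`, `SmolLM3Model`, and
# `SmolLM3Dataset` are assumed names for this repo's config loader and
# model/dataset wrappers -- adjust to the actual modules in your checkout.
# The wrappers must provide the get_training_arguments() method and
# .model/.tokenizer attributes, plus get_train_dataset()/get_eval_dataset()/
# get_data_collator(), which this trainer calls above):
#
#     from config import get_config
#     from model import SmolLM3Model
#     from data import SmolLM3Dataset
#
#     config = get_config("config/train_smollm3.py")
#     model = SmolLM3Model(config)
#     dataset = SmolLM3Dataset(config, model.tokenizer)
#     trainer = SmolLM3Trainer(model, dataset, config, output_dir="./output")
#     trainer.train()
#     trainer.evaluate()
#     trainer.save_model()
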
class SmolLM3DPOTrainer:
    """DPO Trainer for SmolLM3 preference optimization"""

    def __init__(
        self,
        model,
        dataset,
        config,
        output_dir: str,
        ref_model=None
    ):
        self.model = model
        self.dataset = dataset
        self.config = config
        self.output_dir = output_dir
        self.ref_model = ref_model

        # Setup DPO trainer
        self.trainer = self._setup_dpo_trainer()

    def _setup_dpo_trainer(self):
        """Setup DPO trainer"""
        from trl import DPOTrainer

        # Get training arguments
        training_args = self.model.get_training_arguments(
            output_dir=self.output_dir,
            save_steps=self.config.save_steps,
            eval_steps=self.config.eval_steps,
            logging_steps=self.config.logging_steps,
            max_steps=self.config.max_iters,
        )

        # Get preference dataset
        train_dataset = self.dataset.get_train_dataset()
        eval_dataset = self.dataset.get_eval_dataset()

        # Setup DPO trainer
        trainer = DPOTrainer(
            model=self.model.model,
            ref_model=self.ref_model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=self.model.tokenizer,
            max_prompt_length=self.config.max_seq_length // 2,
            max_length=self.config.max_seq_length,
        )

        return trainer

    def train(self):
        """Start DPO training"""
        logger.info("Starting DPO training")

        try:
            train_result = self.trainer.train()

            # Save the final model
            self.trainer.save_model()

            # Save training results
            with open(os.path.join(self.output_dir, "dpo_train_results.json"), "w") as f:
                json.dump(train_result.metrics, f, indent=2)

            logger.info("DPO training completed successfully!")
            logger.info("Training metrics: %s", train_result.metrics)

        except Exception as e:
            logger.error("DPO training failed: %s", e)
            raise
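
# Example usage (same caveats as the sketch above; additionally, the dataset
# wrapper must yield preference pairs in the prompt/chosen/rejected format
# DPOTrainer expects, and leaving ref_model=None lets TRL construct the
# reference model from the policy model internally):
#
#     dpo_trainer = SmolLM3DPOTrainer(model, dataset, config, output_dir="./dpo-output")
#     dpo_trainer.train()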