| """ |
| Training callbacks for monitoring and checkpointing. |
| Integrates with Weights & Biases and TensorBoard. |
| """ |
|
|
| from transformers import TrainerCallback, TrainerState, TrainerControl, TrainingArguments |
| from loguru import logger |
|
|
try:
    import wandb
except ImportError:
    # Weights & Biases is an optional dependency; code below checks HAS_WANDB
    # before touching the ``wandb`` module.
    HAS_WANDB = False
else:
    HAS_WANDB = True
|
|
|
|
class StyleMetricsCallback(TrainerCallback):
    """Log evaluation metrics (including style-similarity scores) after each evaluation.

    Every metric is written to the loguru logger. When ``wandb`` imported
    successfully and a run is active (``wandb.run is not None``), the numeric
    metrics are also sent to Weights & Biases under an ``eval/`` prefix.
    """

    def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Log the ``metrics`` dict supplied for this evaluation.

        Args:
            args: Training arguments (not used).
            state: Trainer state; ``state.global_step`` tags the log entries.
            control: Trainer control object (not used or modified).
            **kwargs: Callback payload; only ``kwargs["metrics"]`` is read.
        """
        # Treat a missing *or* None payload as "nothing to log".
        metrics = kwargs.get("metrics") or {}
        if not metrics:
            return

        logger.info(f"Evaluation metrics at step {state.global_step}:")
        for key, value in metrics.items():
            logger.info(f" {key}: {value:.4f}" if isinstance(value, float) else f" {key}: {value}")

        if not (HAS_WANDB and wandb.run is not None):
            return

        payload = {f"eval/{k}": v for k, v in metrics.items() if isinstance(v, (int, float))}
        if not payload:
            return  # avoid committing an empty wandb step

        # Deliberately no ``step=`` argument: wandb ignores (with only a warning)
        # any log whose explicit step is lower than the run's current step, which
        # happens once another logger (e.g. transformers' WandbCallback) has
        # advanced the run's step beyond ``global_step``. Recording the trainer
        # step as a metric keeps it available as a chart x-axis without risking
        # dropped data.
        payload["train/global_step"] = state.global_step
        wandb.log(payload)
|
|
|
|
class EarlyStoppingOnStyleDrift(TrainerCallback):
    """Stop training when style similarity stays below a threshold for too long.

    After each evaluation the callback reads ``metrics[metric_name]``. Each
    evaluation whose score is below ``min_style_similarity`` increments a
    patience counter; when the counter reaches ``patience`` the callback sets
    ``control.should_training_stop = True``. The counter is reset by a score at
    or above the threshold, and also whenever a new best score is reached (so a
    run that is still improving gets a fresh budget even while below threshold).
    Evaluations that do not report the metric leave the state unchanged.
    """

    def __init__(
        self,
        min_style_similarity: float = 0.75,
        patience: int = 3,
        metric_name: str = "eval_style_similarity",
    ):
        """
        Args:
            min_style_similarity: Scores strictly below this value count against patience.
            patience: Number of counted evaluations that triggers early stopping (>= 1).
            metric_name: Key of the style-similarity score in the evaluation metrics.

        Raises:
            ValueError: If ``patience`` is less than 1.
        """
        if patience < 1:
            raise ValueError(f"patience must be >= 1, got {patience}")
        self.min_style_similarity = min_style_similarity
        self.patience = patience
        self.metric_name = metric_name
        self.best_style_sim = 0.0  # highest score seen so far
        self.patience_counter = 0  # below-threshold evaluations since the last reset

    def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Update the patience counter and request a stop once it is exhausted.

        Args:
            args: Training arguments (not used).
            state: Trainer state (not used).
            control: Trainer control; ``should_training_stop`` is set to stop training.
            **kwargs: Callback payload; only ``kwargs["metrics"]`` is read.
        """
        # Treat a missing *or* None payload as "metric not reported".
        metrics = kwargs.get("metrics") or {}
        style_sim = metrics.get(self.metric_name)
        if style_sim is None:
            return

        if style_sim > self.best_style_sim:
            self.best_style_sim = style_sim
            self.patience_counter = 0

        # Written as ``>=`` so that a NaN score (all comparisons False) counts
        # against patience instead of silently resetting it.
        if style_sim >= self.min_style_similarity:
            self.patience_counter = 0
            return

        self.patience_counter += 1
        logger.warning(
            f"Style similarity {style_sim:.4f} below threshold {self.min_style_similarity}. "
            f"Patience: {self.patience_counter}/{self.patience}"
        )
        if self.patience_counter >= self.patience:
            logger.error(
                f"Early stopping: style similarity consistently below {self.min_style_similarity}"
            )
            control.should_training_stop = True
|
|