"""
训练过程中的Predict监控回调
用于实时监控训练过程中predict的变化和对齐情况
"""

import json
import os
import numpy as np
import torch
from datetime import datetime
from typing import Dict, List, Any, Optional, Tuple
from transformers import TrainerCallback, TrainerState, TrainerControl
from transformers.trainer_utils import PredictionOutput

from enhanced_label_debug import EnhancedLabelDebugger
| |
|
class PredictMonitoringCallback(TrainerCallback):
    """Trainer callback that monitors predict outputs during training.

    Every ``log_interval`` steps it logs a loss/learning-rate snapshot; at
    evaluate/predict time it runs per-sample prediction-vs-label alignment
    analysis through ``EnhancedLabelDebugger``, optionally persisting raw
    predictions and end-of-training trend summaries to JSON files.
    """

    def __init__(self,
                 model_name: str,
                 log_interval: int = 10,
                 save_predictions: bool = True,
                 detailed_analysis: bool = True):
        """Initialize the monitoring callback.

        Args:
            model_name: Model name/path; forwarded to EnhancedLabelDebugger.
            log_interval: Emit a training snapshot every N steps.
            save_predictions: Whether to dump raw predictions to JSON files.
            detailed_analysis: Whether to run per-sample alignment analysis.
        """
        self.model_name = model_name
        self.log_interval = log_interval
        self.save_predictions = save_predictions
        self.detailed_analysis = detailed_analysis

        # The debugger owns the log file; a timestamp keeps runs distinct.
        self.debugger = EnhancedLabelDebugger(
            model_name=model_name,
            log_file=f"/home/ziqiang/LLaMA-Factory/training_predict_monitor_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
        )

        # Accumulated per-step analyses and raw prediction samples.
        self.step_analyses: List[Dict[str, Any]] = []
        self.prediction_history: List[Dict[str, Any]] = []

        self.debugger.log_debug(f"🔧 Predict监控回调初始化完成")
        self.debugger.log_debug(f"📊 日志间隔: {log_interval}步")
        self.debugger.log_debug(f"💾 保存预测: {save_predictions}")
        self.debugger.log_debug(f"🔍 详细分析: {detailed_analysis}")

    def on_step_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """Called at the end of every training step.

        Logs a loss/learning-rate snapshot once every ``log_interval`` steps.
        """
        if state.global_step % self.log_interval == 0:
            self.debugger.log_debug(f"\n🔄 训练步骤 {state.global_step} 监控")
            self.debugger.log_debug(f"{'=' * 60}")

            # Hoist the latest log entry instead of indexing log_history twice.
            latest = state.log_history[-1] if state.log_history else {}
            self.debugger.log_debug(f"📈 当前Loss: {latest.get('loss', 'N/A')}")
            self.debugger.log_debug(f"📊 学习率: {latest.get('learning_rate', 'N/A')}")

            # Bug fix: TrainerState has no `training_time` attribute, so the
            # original line raised AttributeError every log_interval steps.
            # Guard with getattr and only log when the value is available.
            training_time = getattr(state, "training_time", None)
            if training_time is not None:
                self.debugger.log_debug(f"⏱️ 训练时间: {training_time:.2f}秒")

    def on_evaluate(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """Called after each evaluation; analyzes predict results if provided."""
        self.debugger.log_debug(f"\n📊 评估阶段监控")
        self.debugger.log_debug(f"{'=' * 60}")
        self.debugger.log_debug(f"🔄 评估步骤: {state.global_step}")

        # Bug fix: the original used hasattr(kwargs, 'predict_results') on a
        # dict, which checks attributes (always False for dict keys), making
        # this branch unreachable. Use a key lookup instead.
        predict_results = kwargs.get('predict_results')
        if predict_results is not None:
            self._analyze_predictions(predict_results, state.global_step)

    def on_predict(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """Called after a predict pass; analyzes predict results if provided."""
        self.debugger.log_debug(f"\n🔮 预测阶段监控")
        self.debugger.log_debug(f"{'=' * 60}")

        predict_results = kwargs.get('predict_results')
        if predict_results is not None:
            self._analyze_predictions(predict_results, state.global_step)

    def _analyze_predictions(self, predict_results: PredictionOutput, step: int):
        """Analyze a batch of predictions against labels for a given step.

        Inspects up to the first three samples: strips padding, optionally
        runs the debugger's detailed alignment analysis, and records the raw
        token ids in ``prediction_history``.
        """
        self.debugger.log_debug(f"📊 预测结果分析 - 步骤 {step}")

        predictions = predict_results.predictions
        labels = predict_results.label_ids

        if predictions is None or labels is None:
            self.debugger.log_debug("⚠️ 预测结果或标签为空")
            return

        # Normalize to numpy so the per-sample slicing below is uniform.
        # NOTE(review): assumes predictions is a single array/tensor, not the
        # tuple some models return — confirm against the Trainer in use.
        if isinstance(predictions, torch.Tensor):
            predictions = predictions.cpu().numpy()
        if isinstance(labels, torch.Tensor):
            labels = labels.cpu().numpy()

        batch_size = len(predictions)
        self.debugger.log_debug(f"📦 批次大小: {batch_size}")

        # Only the first few samples are analyzed to keep logging cheap.
        for i in range(min(batch_size, 3)):
            self.debugger.log_debug(f"\n🔍 样本 {i+1} 分析:")

            pred_sample = self._remove_padding(predictions[i])
            label_sample = self._remove_padding(labels[i])

            if self.detailed_analysis:
                analysis = self.debugger.analyze_training_step(
                    step=step,
                    predictions=pred_sample.tolist(),
                    labels=label_sample.tolist(),
                    loss=predict_results.metrics.get('eval_loss', None) if hasattr(predict_results, 'metrics') else None
                )
                self.step_analyses.append(analysis)

            self.prediction_history.append({
                "step": step,
                "sample_idx": i,
                "predictions": pred_sample.tolist(),
                "labels": label_sample.tolist(),
                "timestamp": datetime.now().isoformat()
            })

        if self.save_predictions:
            self._save_predictions(predict_results, step)

    def _remove_padding(self, tokens: np.ndarray, pad_token_id: Optional[int] = None) -> np.ndarray:
        """Strip leading and trailing pad tokens from a 1-D token array.

        Interior pad tokens are kept; if the array is entirely padding it is
        returned unchanged.

        NOTE(review): label arrays from the Trainer typically use -100 as the
        ignore index rather than pad_token_id — confirm this strips what the
        caller expects for labels.
        """
        if pad_token_id is None:
            pad_token_id = self.debugger.tokenizer.pad_token_id

        non_pad_mask = tokens != pad_token_id
        if np.any(non_pad_mask):
            # argmax on a boolean mask returns the first True index.
            first_non_pad = np.argmax(non_pad_mask)
            last_non_pad = len(tokens) - 1 - np.argmax(non_pad_mask[::-1])
            return tokens[first_non_pad:last_non_pad+1]
        else:
            return tokens

    def _save_predictions(self, predict_results: PredictionOutput, step: int):
        """Persist raw predictions, labels and metrics for one step as JSON."""
        output_dir = "/home/ziqiang/LLaMA-Factory/prediction_monitoring"
        os.makedirs(output_dir, exist_ok=True)

        pred_file = os.path.join(output_dir, f"predictions_step_{step}.json")
        with open(pred_file, "w", encoding="utf-8") as f:
            json.dump({
                "step": step,
                "timestamp": datetime.now().isoformat(),
                "predictions": predict_results.predictions.tolist() if isinstance(predict_results.predictions, np.ndarray) else predict_results.predictions,
                "label_ids": predict_results.label_ids.tolist() if isinstance(predict_results.label_ids, np.ndarray) else predict_results.label_ids,
                "metrics": predict_results.metrics if hasattr(predict_results, 'metrics') else {}
            }, f, ensure_ascii=False, indent=2)

        self.debugger.log_debug(f"💾 预测结果已保存到: {pred_file}")

    def on_train_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """Called when training finishes; writes summary and trend reports."""
        self.debugger.log_debug(f"\n🏁 训练结束监控")
        self.debugger.log_debug(f"{'=' * 60}")

        if self.step_analyses:
            self.debugger.save_analysis_summary(
                self.step_analyses,
                f"/home/ziqiang/LLaMA-Factory/training_summary_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
            )

        self._generate_training_trends()

    def _generate_training_trends(self):
        """Log and persist loss / match-rate trends over collected analyses."""
        if not self.step_analyses:
            return

        self.debugger.log_debug(f"\n📈 训练趋势分析")
        self.debugger.log_debug(f"{'=' * 60}")

        # NOTE(review): assumes analyze_training_step returns dicts with
        # "step", "loss" and "alignment_analysis.valid_match_percentage"
        # keys — verify against EnhancedLabelDebugger.
        steps = [analysis["step"] for analysis in self.step_analyses]
        losses = [analysis["loss"] for analysis in self.step_analyses if analysis["loss"] is not None]
        valid_match_percentages = [
            analysis["alignment_analysis"]["valid_match_percentage"]
            for analysis in self.step_analyses
        ]

        if losses:
            # Bug fix: the original printed min -> max, which is a value
            # range, not a trend. Report first -> last instead.
            self.debugger.log_debug(f"📉 Loss趋势: {losses[0]:.6f} -> {losses[-1]:.6f}")

        if valid_match_percentages:
            self.debugger.log_debug(f"🎯 有效匹配率趋势: {valid_match_percentages[0]:.1f}% -> {valid_match_percentages[-1]:.1f}%")

        trend_data = {
            "steps": steps,
            "losses": losses,
            "valid_match_percentages": valid_match_percentages,
            "analysis_time": datetime.now().isoformat()
        }

        trend_file = f"/home/ziqiang/LLaMA-Factory/training_trends_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        with open(trend_file, "w", encoding="utf-8") as f:
            json.dump(trend_data, f, ensure_ascii=False, indent=2)

        self.debugger.log_debug(f"📊 趋势数据已保存到: {trend_file}")
| |
|
def create_predict_monitoring_callback(model_name: str, **kwargs) -> PredictMonitoringCallback:
    """Build a PredictMonitoringCallback for the given model.

    Extra keyword arguments (log_interval, save_predictions,
    detailed_analysis) are forwarded unchanged to the constructor.
    """
    callback = PredictMonitoringCallback(model_name=model_name, **kwargs)
    return callback
| |
|
| | |
if __name__ == "__main__":
    # Smoke test: build a callback with a short logging interval and report
    # where its debug log will be written.
    monitor = create_predict_monitoring_callback(
        model_name="/data/models/Qwen3-8B",
        log_interval=5,
        save_predictions=True,
        detailed_analysis=True,
    )

    print("✅ Predict监控回调创建完成")
    print(f"📁 日志文件: {monitor.debugger.log_file}")
    print(f"📊 监控间隔: {monitor.log_interval}步")
| |
|