Spaces:
Paused
Paused
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import seaborn as sns | |
| from sklearn.metrics import confusion_matrix | |
| import pandas as pd | |
| def plot_confusion_matrix(y_true, y_pred, classes, title='Confusion Matrix', cmap=plt.cm.Blues): | |
| """ | |
| Vẽ Confusion Matrix chuyên nghiệp cho các câu hỏi Closed-ended (Yes/No). | |
| """ | |
| cm = confusion_matrix(y_true, y_pred) | |
| plt.figure(figsize=(8, 6)) | |
| sns.heatmap(cm, annot=True, fmt='d', cmap=cmap, | |
| xticklabels=classes, yticklabels=classes) | |
| plt.title(title, fontsize=15) | |
| plt.ylabel('Ground Truth', fontsize=12) | |
| plt.xlabel('Predicted', fontsize=12) | |
| plt.tight_layout() | |
| return plt | |
| def plot_radar_chart(model_names, metrics_data, categories, title='Model Comparison (All Variants)'): | |
| """ | |
| Vẽ biểu đồ Radar để so sánh 5 biến thể trên nhiều tiêu chí (Accuracy, BLEU, ROUGE, BERTScore). | |
| metrics_data: List of lists, mỗi list là chỉ số của 1 model. | |
| """ | |
| N = len(categories) | |
| angles = [n / float(N) * 2 * np.pi for n in range(N)] | |
| angles += angles[:1] | |
| fig, ax = plt.subplots(figsize=(10, 10), subplot_kw=dict(polar=True)) | |
| for i, model_name in enumerate(model_names): | |
| values = metrics_data[i] | |
| values += values[:1] | |
| ax.plot(angles, values, linewidth=2, linestyle='solid', label=model_name) | |
| ax.fill(angles, values, alpha=0.1) | |
| ax.set_theta_offset(np.pi / 2) | |
| ax.set_theta_direction(-1) | |
| plt.xticks(angles[:-1], categories, fontsize=12) | |
| plt.legend(loc='upper right', bbox_to_anchor=(0.1, 0.1)) | |
| plt.title(title, size=20, y=1.1) | |
| return plt | |
| def plot_training_history(history, title='Training History'): | |
| """ | |
| Vẽ đồ thị Loss và Accuracy trong quá trình huấn luyện. | |
| history: dict có keys 'train_loss', 'val_acc', v.v. | |
| """ | |
| fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5)) | |
| # Loss plot | |
| ax1.plot(history['train_loss'], label='Train Loss') | |
| if 'val_loss' in history: | |
| ax1.plot(history['val_loss'], label='Val Loss') | |
| ax1.set_title('Loss Evolution') | |
| ax1.set_xlabel('Epochs') | |
| ax1.set_ylabel('Loss') | |
| ax1.legend() | |
| ax1.grid(True) | |
| # Accuracy plot | |
| ax2.plot(history['val_acc'], label='Val Accuracy', color='green') | |
| ax2.set_title('Accuracy Evolution') | |
| ax2.set_xlabel('Epochs') | |
| ax2.set_ylabel('Accuracy') | |
| ax2.legend() | |
| ax2.grid(True) | |
| plt.suptitle(title, fontsize=16) | |
| plt.tight_layout() | |
| return plt | |
| def plot_benchmark_comparison(results_df, metric='Accuracy'): | |
| """ | |
| Biểu đồ cột so sánh một chỉ số cụ thể giữa các mô hình. | |
| results_df: DataFrame có cột 'Model' và các chỉ số. | |
| """ | |
| plt.figure(figsize=(10, 6)) | |
| sns.set_style("whitegrid") | |
| ax = sns.barplot(x='Model', y=metric, data=results_df, palette='viridis') | |
| for p in ax.patches: | |
| ax.annotate(format(p.get_height(), '.4f'), | |
| (p.get_x() + p.get_width() / 2., p.get_height()), | |
| ha = 'center', va = 'center', | |
| xytext = (0, 9), | |
| textcoords = 'offset points', | |
| fontsize=11) | |
| plt.title(f'Comparison of {metric} across Variants', fontsize=15) | |
| plt.ylim(0, 1.1) | |
| plt.tight_layout() | |
| return plt | |
| def plot_accuracy_by_category(data_df, category_col='Organ', title='Accuracy by Medical Category'): | |
| """ | |
| Biểu đồ cột phân nhóm để so sánh độ chính xác giữa các cơ quan hoặc loại câu hỏi. | |
| data_df: DataFrame có cột category_col, 'Model', và 'Correct' (bool). | |
| """ | |
| acc_df = data_df.groupby([category_col, 'Model'])['Correct'].mean().reset_index() | |
| plt.figure(figsize=(12, 6)) | |
| sns.barplot(x=category_col, y='Correct', hue='Model', data=acc_df) | |
| plt.title(title, fontsize=15) | |
| plt.ylabel('Accuracy') | |
| plt.xticks(rotation=45) | |
| plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left') | |
| plt.tight_layout() | |
| return plt | |
| def plot_semantic_distribution(model_scores_dict, title='Semantic Score Distribution (LLM-Judge)'): | |
| """ | |
| Vẽ biểu đồ Violin để so sánh phân bổ điểm số ngữ nghĩa giữa các model (ví dụ B2 vs DPO). | |
| model_scores_dict: {'Model A': [scores], 'Model B': [scores]} | |
| """ | |
| data = [] | |
| for model, scores in model_scores_dict.items(): | |
| for s in scores: | |
| data.append({'Model': model, 'Score': s}) | |
| df = pd.DataFrame(data) | |
| plt.figure(figsize=(10, 6)) | |
| sns.violinplot(x='Model', y='Score', data=df, inner="quart", palette="Set3") | |
| plt.title(title, fontsize=15) | |
| plt.ylim(-0.1, 1.1) | |
| plt.tight_layout() | |
| return plt | |
| def plot_latency_vs_accuracy(model_stats, title='Accuracy vs. Latency Trade-off'): | |
| """ | |
| Biểu đồ bong bóng so sánh Tốc độ và Độ chính xác. | |
| model_stats: List of dicts [{'name': 'A1', 'accuracy': 0.8, 'latency': 0.1, 'params': 100M}, ...] | |
| """ | |
| df = pd.DataFrame(model_stats) | |
| plt.figure(figsize=(10, 7)) | |
| scatter = plt.scatter(df['latency'], df['accuracy'], | |
| s=df['params_mb']*10, # Kích thước bong bóng theo số lượng tham số | |
| alpha=0.5, c=np.arange(len(df)), cmap='viridis') | |
| for i, txt in enumerate(df['name']): | |
| plt.annotate(txt, (df['latency'][i], df['accuracy'][i]), fontsize=12) | |
| plt.xlabel('Latency (seconds/sample)', fontsize=12) | |
| plt.ylabel('Accuracy', fontsize=12) | |
| plt.title(title, fontsize=15) | |
| plt.grid(True, linestyle='--', alpha=0.6) | |
| plt.tight_layout() | |
| return plt | |
| def plot_calibration_curve(y_true, y_probs, n_bins=10, title='Calibration Curve (Reliability)'): | |
| """ | |
| Biểu đồ hiệu chuẩn để xem độ tin cậy của xác suất dự đoán. | |
| y_true: nhãn thực tế [0, 1] | |
| y_probs: xác suất dự đoán lớp 1 | |
| """ | |
| from sklearn.calibration import calibration_curve | |
| prob_true, prob_pred = calibration_curve(y_true, y_probs, n_bins=n_bins) | |
| plt.figure(figsize=(8, 8)) | |
| plt.plot(prob_pred, prob_true, "s-", label='Model') | |
| plt.plot([0, 1], [0, 1], "k--", label='Perfectly Calibrated') | |
| plt.ylabel('Fraction of Positives', fontsize=12) | |
| plt.xlabel('Mean Predicted Probability', fontsize=12) | |
| plt.title(title, fontsize=15) | |
| plt.legend(loc="lower right") | |
| plt.grid(True) | |
| plt.tight_layout() | |
| return plt | |
| def plot_performance_vs_length(questions, corrects, title='Accuracy vs. Question Length'): | |
| """ | |
| Biểu đồ xem độ chính xác có giảm khi câu hỏi dài hơn không. | |
| questions: list các câu hỏi. | |
| corrects: list các giá trị bool (đúng/sai). | |
| """ | |
| lengths = [len(q.split()) for q in questions] | |
| df = pd.DataFrame({'Length': lengths, 'Correct': corrects}) | |
| # Chia nhóm độ dài (bins) | |
| df['Length_Group'] = pd.cut(df['Length'], bins=[0, 5, 10, 15, 20, 30, 50], | |
| labels=['1-5', '6-10', '11-15', '16-20', '21-30', '31+']) | |
| acc_by_len = df.groupby('Length_Group')['Correct'].mean().reset_index() | |
| plt.figure(figsize=(10, 6)) | |
| sns.lineplot(x='Length_Group', y='Correct', data=acc_by_len, marker='o', color='red') | |
| plt.title(title, fontsize=15) | |
| plt.ylabel('Accuracy') | |
| plt.xlabel('Question Length (words)') | |
| plt.ylim(0, 1.1) | |
| plt.grid(True, axis='y') | |
| plt.tight_layout() | |
| return plt |