Medical-VQA / src /utils /evaluation_viz.py
SpringWang08's picture
Deploy Medical VQA app
d63774a
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.metrics import confusion_matrix
import pandas as pd
def plot_confusion_matrix(y_true, y_pred, classes, title='Confusion Matrix', cmap=plt.cm.Blues):
"""
Vẽ Confusion Matrix chuyên nghiệp cho các câu hỏi Closed-ended (Yes/No).
"""
cm = confusion_matrix(y_true, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap=cmap,
xticklabels=classes, yticklabels=classes)
plt.title(title, fontsize=15)
plt.ylabel('Ground Truth', fontsize=12)
plt.xlabel('Predicted', fontsize=12)
plt.tight_layout()
return plt
def plot_radar_chart(model_names, metrics_data, categories, title='Model Comparison (All Variants)'):
"""
Vẽ biểu đồ Radar để so sánh 5 biến thể trên nhiều tiêu chí (Accuracy, BLEU, ROUGE, BERTScore).
metrics_data: List of lists, mỗi list là chỉ số của 1 model.
"""
N = len(categories)
angles = [n / float(N) * 2 * np.pi for n in range(N)]
angles += angles[:1]
fig, ax = plt.subplots(figsize=(10, 10), subplot_kw=dict(polar=True))
for i, model_name in enumerate(model_names):
values = metrics_data[i]
values += values[:1]
ax.plot(angles, values, linewidth=2, linestyle='solid', label=model_name)
ax.fill(angles, values, alpha=0.1)
ax.set_theta_offset(np.pi / 2)
ax.set_theta_direction(-1)
plt.xticks(angles[:-1], categories, fontsize=12)
plt.legend(loc='upper right', bbox_to_anchor=(0.1, 0.1))
plt.title(title, size=20, y=1.1)
return plt
def plot_training_history(history, title='Training History'):
"""
Vẽ đồ thị Loss và Accuracy trong quá trình huấn luyện.
history: dict có keys 'train_loss', 'val_acc', v.v.
"""
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
# Loss plot
ax1.plot(history['train_loss'], label='Train Loss')
if 'val_loss' in history:
ax1.plot(history['val_loss'], label='Val Loss')
ax1.set_title('Loss Evolution')
ax1.set_xlabel('Epochs')
ax1.set_ylabel('Loss')
ax1.legend()
ax1.grid(True)
# Accuracy plot
ax2.plot(history['val_acc'], label='Val Accuracy', color='green')
ax2.set_title('Accuracy Evolution')
ax2.set_xlabel('Epochs')
ax2.set_ylabel('Accuracy')
ax2.legend()
ax2.grid(True)
plt.suptitle(title, fontsize=16)
plt.tight_layout()
return plt
def plot_benchmark_comparison(results_df, metric='Accuracy'):
"""
Biểu đồ cột so sánh một chỉ số cụ thể giữa các mô hình.
results_df: DataFrame có cột 'Model' và các chỉ số.
"""
plt.figure(figsize=(10, 6))
sns.set_style("whitegrid")
ax = sns.barplot(x='Model', y=metric, data=results_df, palette='viridis')
for p in ax.patches:
ax.annotate(format(p.get_height(), '.4f'),
(p.get_x() + p.get_width() / 2., p.get_height()),
ha = 'center', va = 'center',
xytext = (0, 9),
textcoords = 'offset points',
fontsize=11)
plt.title(f'Comparison of {metric} across Variants', fontsize=15)
plt.ylim(0, 1.1)
plt.tight_layout()
return plt
def plot_accuracy_by_category(data_df, category_col='Organ', title='Accuracy by Medical Category'):
"""
Biểu đồ cột phân nhóm để so sánh độ chính xác giữa các cơ quan hoặc loại câu hỏi.
data_df: DataFrame có cột category_col, 'Model', và 'Correct' (bool).
"""
acc_df = data_df.groupby([category_col, 'Model'])['Correct'].mean().reset_index()
plt.figure(figsize=(12, 6))
sns.barplot(x=category_col, y='Correct', hue='Model', data=acc_df)
plt.title(title, fontsize=15)
plt.ylabel('Accuracy')
plt.xticks(rotation=45)
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
return plt
def plot_semantic_distribution(model_scores_dict, title='Semantic Score Distribution (LLM-Judge)'):
"""
Vẽ biểu đồ Violin để so sánh phân bổ điểm số ngữ nghĩa giữa các model (ví dụ B2 vs DPO).
model_scores_dict: {'Model A': [scores], 'Model B': [scores]}
"""
data = []
for model, scores in model_scores_dict.items():
for s in scores:
data.append({'Model': model, 'Score': s})
df = pd.DataFrame(data)
plt.figure(figsize=(10, 6))
sns.violinplot(x='Model', y='Score', data=df, inner="quart", palette="Set3")
plt.title(title, fontsize=15)
plt.ylim(-0.1, 1.1)
plt.tight_layout()
return plt
def plot_latency_vs_accuracy(model_stats, title='Accuracy vs. Latency Trade-off'):
"""
Biểu đồ bong bóng so sánh Tốc độ và Độ chính xác.
model_stats: List of dicts [{'name': 'A1', 'accuracy': 0.8, 'latency': 0.1, 'params': 100M}, ...]
"""
df = pd.DataFrame(model_stats)
plt.figure(figsize=(10, 7))
scatter = plt.scatter(df['latency'], df['accuracy'],
s=df['params_mb']*10, # Kích thước bong bóng theo số lượng tham số
alpha=0.5, c=np.arange(len(df)), cmap='viridis')
for i, txt in enumerate(df['name']):
plt.annotate(txt, (df['latency'][i], df['accuracy'][i]), fontsize=12)
plt.xlabel('Latency (seconds/sample)', fontsize=12)
plt.ylabel('Accuracy', fontsize=12)
plt.title(title, fontsize=15)
plt.grid(True, linestyle='--', alpha=0.6)
plt.tight_layout()
return plt
def plot_calibration_curve(y_true, y_probs, n_bins=10, title='Calibration Curve (Reliability)'):
"""
Biểu đồ hiệu chuẩn để xem độ tin cậy của xác suất dự đoán.
y_true: nhãn thực tế [0, 1]
y_probs: xác suất dự đoán lớp 1
"""
from sklearn.calibration import calibration_curve
prob_true, prob_pred = calibration_curve(y_true, y_probs, n_bins=n_bins)
plt.figure(figsize=(8, 8))
plt.plot(prob_pred, prob_true, "s-", label='Model')
plt.plot([0, 1], [0, 1], "k--", label='Perfectly Calibrated')
plt.ylabel('Fraction of Positives', fontsize=12)
plt.xlabel('Mean Predicted Probability', fontsize=12)
plt.title(title, fontsize=15)
plt.legend(loc="lower right")
plt.grid(True)
plt.tight_layout()
return plt
def plot_performance_vs_length(questions, corrects, title='Accuracy vs. Question Length'):
"""
Biểu đồ xem độ chính xác có giảm khi câu hỏi dài hơn không.
questions: list các câu hỏi.
corrects: list các giá trị bool (đúng/sai).
"""
lengths = [len(q.split()) for q in questions]
df = pd.DataFrame({'Length': lengths, 'Correct': corrects})
# Chia nhóm độ dài (bins)
df['Length_Group'] = pd.cut(df['Length'], bins=[0, 5, 10, 15, 20, 30, 50],
labels=['1-5', '6-10', '11-15', '16-20', '21-30', '31+'])
acc_by_len = df.groupby('Length_Group')['Correct'].mean().reset_index()
plt.figure(figsize=(10, 6))
sns.lineplot(x='Length_Group', y='Correct', data=acc_by_len, marker='o', color='red')
plt.title(title, fontsize=15)
plt.ylabel('Accuracy')
plt.xlabel('Question Length (words)')
plt.ylim(0, 1.1)
plt.grid(True, axis='y')
plt.tight_layout()
return plt