"""Visualization utilities for the Korean Q&A evaluation system."""

import logging
import os
from typing import Dict, List, Any

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
from plotly.subplots import make_subplots

# Module-level logger named after this module, so applications can configure
# its level and handlers through the standard logging hierarchy.
logger = logging.getLogger(__name__)


class EvaluationVisualizer:
    """Create visualizations for evaluation results.

    Every ``create_*`` method takes the evaluation results dictionary and
    returns a Plotly figure. Failures inside those methods are logged and
    converted into a placeholder figure (see ``_create_empty_figure``)
    instead of being raised.

    The results dictionary is always read with ``dict.get``. The keys used
    are ``detailed_results`` (a list of per-case dicts holding ``input``,
    ``actual_output`` and ``metrics``, where ``metrics`` maps a metric name
    to a dict with a ``score`` entry), ``threshold``, ``total_cases``,
    ``passed_cases``, ``failed_cases``, ``pass_rate``, ``average_score``,
    ``model_name`` and ``evaluation_timestamp``.
    """

    def __init__(self):
        """Apply the matplotlib style sheet and seaborn palette used for plots.

        Both settings are global to matplotlib/seaborn, not per instance.
        """
        try:
            plt.style.use("seaborn-v0_8")
        except OSError:
            # Matplotlib releases before 3.6 only ship this style sheet under
            # its original name.
            plt.style.use("seaborn")
        sns.set_palette("husl")

    @staticmethod
    def _normalize_metric_name(name: str) -> str:
        """Return *name* lower-cased with spaces, underscores and hyphens removed."""
        return "".join(ch for ch in name.lower() if ch not in " _-")

    def create_score_histogram(
        self, results: Dict[str, Any], metric_name: str = "Answer Relevancy"
    ) -> go.Figure:
        """Create a histogram of the scores recorded for one metric.

        A metric key matches when the normalized ``metric_name`` (case,
        spaces, underscores and hyphens ignored) is contained in the
        normalized key, so "Answer Relevancy" matches e.g.
        "AnswerRelevancyMetric". Only the requested metric is collected.

        Args:
            results: Evaluation results dictionary.
            metric_name: Name of the metric to visualize.

        Returns:
            Plotly figure with the histogram and a dashed threshold line, or
            a placeholder figure when no score matches or an error occurs.
        """
        try:
            wanted = self._normalize_metric_name(metric_name)
            scores = []
            for result in results.get("detailed_results", []):
                for metric, data in result.get("metrics", {}).items():
                    if wanted in self._normalize_metric_name(metric):
                        scores.append(data.get("score", 0))

            if not scores:
                logger.warning("No scores found for metric: %s", metric_name)
                return self._create_empty_figure("No data available")

            fig = go.Figure()
            fig.add_trace(
                go.Histogram(
                    x=scores,
                    nbinsx=20,
                    name=metric_name,
                    marker_color="skyblue",
                    opacity=0.7,
                    hovertemplate="Score Range: %{x}<br>Count: %{y}<extra></extra>",
                )
            )

            # Mark the pass/fail threshold on the score axis.
            threshold = results.get("threshold", 0.8)
            fig.add_vline(
                x=threshold,
                line_dash="dash",
                line_color="red",
                annotation_text=f"Threshold: {threshold}",
                annotation_position="top right",
            )

            fig.update_layout(
                title=f"{metric_name} Score Distribution",
                xaxis_title="Score",
                yaxis_title="Frequency",
                showlegend=False,
                template="plotly_white",
                height=400,
            )
            return fig

        except Exception as e:
            logger.exception("Error creating score histogram: %s", e)
            return self._create_empty_figure("Error creating histogram")

    def create_pass_fail_pie_chart(self, results: Dict[str, Any]) -> go.Figure:
        """Create a donut chart showing the pass/fail distribution.

        Args:
            results: Evaluation results dictionary; ``passed_cases`` and
                ``failed_cases`` are read, each defaulting to 0.

        Returns:
            Plotly figure with the chart, or a placeholder figure when both
            counts are zero or an error occurs.
        """
        try:
            passed = results.get("passed_cases", 0)
            failed = results.get("failed_cases", 0)

            if passed == 0 and failed == 0:
                return self._create_empty_figure("No evaluation data available")

            fig = go.Figure(
                data=[
                    go.Pie(
                        labels=["Passed", "Failed"],
                        values=[passed, failed],
                        hole=0.3,
                        marker_colors=["#2E8B57", "#DC143C"],
                        hovertemplate="%{label}: %{value} cases<br>%{percent}<extra></extra>",
                    )
                ]
            )
            fig.update_layout(
                title="Pass/Fail Distribution",
                template="plotly_white",
                height=400,
            )
            return fig

        except Exception as e:
            logger.exception("Error creating pie chart: %s", e)
            return self._create_empty_figure("Error creating pie chart")

    def create_metrics_comparison(self, results: Dict[str, Any]) -> go.Figure:
        """Create one score histogram per metric, stacked vertically.

        Args:
            results: Evaluation results dictionary; scores are grouped by the
                metric names found in each entry of ``detailed_results``.

        Returns:
            Plotly figure with one subplot row per metric (300 px each), or a
            placeholder figure when there are no metrics or an error occurs.
        """
        try:
            # Group scores by metric name, keeping first-seen order.
            metrics_data: Dict[str, List[Any]] = {}
            for result in results.get("detailed_results", []):
                for metric_name, data in result.get("metrics", {}).items():
                    metrics_data.setdefault(metric_name, []).append(
                        data.get("score", 0)
                    )

            if not metrics_data:
                return self._create_empty_figure("No metrics data available")

            row_count = len(metrics_data)
            # make_subplots rejects a vertical_spacing above 1 / (rows - 1).
            # Keep the 0.1 gap for small grids and shrink it for larger ones
            # so the gaps never take more than half of the figure height.
            if row_count > 1:
                spacing = min(0.1, 0.5 / (row_count - 1))
            else:
                spacing = 0.1

            fig = make_subplots(
                rows=row_count,
                cols=1,
                subplot_titles=list(metrics_data.keys()),
                vertical_spacing=spacing,
            )

            colors = px.colors.qualitative.Set3
            for i, (metric_name, scores) in enumerate(metrics_data.items()):
                fig.add_trace(
                    go.Histogram(
                        x=scores,
                        name=metric_name,
                        # Cycle through the palette when metrics outnumber colors.
                        marker_color=colors[i % len(colors)],
                        opacity=0.7,
                        nbinsx=15,
                    ),
                    row=i + 1,
                    col=1,
                )

            fig.update_layout(
                title="Metrics Comparison",
                template="plotly_white",
                height=300 * row_count,
                showlegend=False,
            )
            return fig

        except Exception as e:
            logger.exception("Error creating metrics comparison: %s", e)
            return self._create_empty_figure("Error creating metrics comparison")

    def create_score_vs_length_scatter(self, results: Dict[str, Any]) -> go.Figure:
        """Create scatter plots of score against input and output text length.

        For each entry of ``detailed_results`` the score of its first metric
        is used (0 when the entry has no metrics). Lengths are ``len()`` of
        the ``input`` and ``actual_output`` values; a missing or ``None``
        text counts as length 0.

        Args:
            results: Evaluation results dictionary.

        Returns:
            Plotly figure with two stacked subplots (input length on top,
            output length below), or a placeholder figure when there are no
            detailed results or an error occurs.
        """
        try:
            scores = []
            input_lengths = []
            output_lengths = []

            for result in results.get("detailed_results", []):
                input_lengths.append(len(result.get("input") or ""))
                output_lengths.append(len(result.get("actual_output") or ""))

                # Only the first metric of each case is plotted.
                first_metric = next(iter(result.get("metrics", {}).values()), None)
                if first_metric is None:
                    scores.append(0)
                else:
                    scores.append(first_metric.get("score", 0))

            if not scores:
                return self._create_empty_figure("No data available for scatter plot")

            fig = make_subplots(
                rows=2,
                cols=1,
                subplot_titles=["Score vs Input Length", "Score vs Output Length"],
            )

            fig.add_trace(
                go.Scatter(
                    x=input_lengths,
                    y=scores,
                    mode="markers",
                    name="Input Length",
                    marker=dict(color="blue", opacity=0.6),
                    hovertemplate="Input Length: %{x}<br>Score: %{y:.3f}<extra></extra>",
                ),
                row=1,
                col=1,
            )
            fig.add_trace(
                go.Scatter(
                    x=output_lengths,
                    y=scores,
                    mode="markers",
                    name="Output Length",
                    marker=dict(color="red", opacity=0.6),
                    hovertemplate="Output Length: %{x}<br>Score: %{y:.3f}<extra></extra>",
                ),
                row=2,
                col=1,
            )

            fig.update_layout(
                title="Score vs Text Length Analysis",
                template="plotly_white",
                height=400,
                showlegend=False,
            )

            # The grid is 2 rows x 1 column, so the output-length subplot is
            # addressed as row 2, column 1.
            fig.update_xaxes(title_text="Input Length (characters)", row=1, col=1)
            fig.update_xaxes(title_text="Output Length (characters)", row=2, col=1)
            fig.update_yaxes(title_text="Score", row=1, col=1)
            fig.update_yaxes(title_text="Score", row=2, col=1)
            return fig

        except Exception as e:
            logger.exception("Error creating scatter plot: %s", e)
            return self._create_empty_figure("Error creating scatter plot")

    def create_summary_stats_table(self, results: Dict[str, Any]) -> go.Figure:
        """Create a two-column table of summary statistics.

        Args:
            results: Evaluation results dictionary; counts, ``pass_rate``,
                ``average_score``, ``threshold``, ``model_name`` and
                ``evaluation_timestamp`` are read with defaults.

        Returns:
            Plotly figure containing the table, or a placeholder figure when
            an error occurs (e.g. a non-numeric ``pass_rate``).
        """
        try:
            stats = [
                ["Total Cases", results.get("total_cases", 0)],
                ["Passed Cases", results.get("passed_cases", 0)],
                ["Failed Cases", results.get("failed_cases", 0)],
                ["Pass Rate", f"{results.get('pass_rate', 0):.2f}%"],
                ["Average Score", f"{results.get('average_score', 0):.4f}"],
                ["Threshold", results.get("threshold", 0.8)],
                ["Model", results.get("model_name", "N/A")],
                ["Evaluation Time", results.get("evaluation_timestamp", "N/A")],
            ]

            fig = go.Figure(
                data=[
                    go.Table(
                        header=dict(
                            values=["Metric", "Value"],
                            fill_color="lightblue",
                            align="left",
                            font=dict(size=14, color="black"),
                        ),
                        cells=dict(
                            # go.Table takes column-wise data: transpose the
                            # (label, value) rows into two columns.
                            values=list(zip(*stats)),
                            fill_color="white",
                            align="left",
                            font=dict(size=12),
                        ),
                    )
                ]
            )
            fig.update_layout(
                title="Evaluation Summary",
                height=400,
                margin=dict(l=0, r=0, t=30, b=0),
            )
            return fig

        except Exception as e:
            logger.exception("Error creating summary table: %s", e)
            return self._create_empty_figure("Error creating summary table")

    def _create_empty_figure(self, message: str) -> go.Figure:
        """Create an axis-less placeholder figure showing *message* centered.

        Args:
            message: Text to display in the middle of the figure.

        Returns:
            Plotly figure containing only the annotation.
        """
        fig = go.Figure()
        fig.add_annotation(
            text=message,
            xref="paper",
            yref="paper",
            x=0.5,
            y=0.5,
            showarrow=False,
            font=dict(size=16),
        )
        fig.update_layout(
            template="plotly_white",
            height=400,
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
        )
        return fig

    def save_matplotlib_plots(
        self, results: Dict[str, Any], output_dir: str = "plots"
    ) -> List[str]:
        """Save a score histogram and a box plot as PNG files.

        The score of the first metric of each entry in ``detailed_results``
        is used; entries without metrics are skipped.

        Args:
            results: Evaluation results dictionary.
            output_dir: Directory to write into; created if missing.

        Returns:
            Paths of the files written (``score_histogram.png`` and
            ``score_boxplot.png``). Empty when there are no scores or when
            plotting or saving fails.

        Raises:
            OSError: If ``output_dir`` cannot be created.
        """
        os.makedirs(output_dir, exist_ok=True)
        saved_files = []

        try:
            scores = []
            for result in results.get("detailed_results", []):
                first_metric = next(iter(result.get("metrics", {}).values()), None)
                if first_metric is not None:
                    scores.append(first_metric.get("score", 0))

            if scores:
                threshold = results.get("threshold", 0.8)

                hist_file = os.path.join(output_dir, "score_histogram.png")
                fig = plt.figure(figsize=(10, 6))
                try:
                    plt.hist(
                        scores, bins=20, alpha=0.7, color="skyblue", edgecolor="black"
                    )
                    plt.axvline(
                        threshold,
                        color="red",
                        linestyle="--",
                        label=f"Threshold: {threshold}",
                    )
                    plt.xlabel("Score")
                    plt.ylabel("Frequency")
                    plt.title("Score Distribution")
                    plt.legend()
                    plt.grid(True, alpha=0.3)
                    plt.savefig(hist_file, dpi=300, bbox_inches="tight")
                finally:
                    # Release the figure even when drawing or saving fails.
                    plt.close(fig)
                saved_files.append(hist_file)

                box_file = os.path.join(output_dir, "score_boxplot.png")
                fig = plt.figure(figsize=(8, 6))
                try:
                    plt.boxplot(scores)
                    # Label the single box through the tick API; the boxplot
                    # ``labels`` keyword is deprecated in newer Matplotlib.
                    plt.xticks([1], ["Scores"])
                    plt.ylabel("Score")
                    plt.title("Score Distribution (Box Plot)")
                    plt.grid(True, alpha=0.3)
                    plt.savefig(box_file, dpi=300, bbox_inches="tight")
                finally:
                    plt.close(fig)
                saved_files.append(box_file)

            logger.info(
                "Saved %d matplotlib plots to %s", len(saved_files), output_dir
            )
            return saved_files

        except Exception as e:
            logger.exception("Error saving matplotlib plots: %s", e)
            return []