# Author: A7m0d
# Commit message: "Upload folder using huggingface_hub"
# Commit: 7dfe46c (verified)
"""Visualization utilities for the Korean Q&A evaluation system."""
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import pandas as pd
import numpy as np
from typing import Dict, List, Any
import logging
# Module-level logger named after this module; the visualizer methods report
# missing data and caught errors through it.
logger = logging.getLogger(__name__)
class EvaluationVisualizer:
    """Create visualizations for evaluation results.

    The ``create_*`` methods build interactive Plotly figures. Each one wraps its
    work in ``try/except``. On missing data or an internal error it logs the
    problem and returns a placeholder figure carrying a short message.
    ``save_matplotlib_plots`` writes static PNG files instead.

    Keys read from ``results`` (all optional):

    - ``detailed_results``: a list of per-case dicts with ``input``,
      ``actual_output`` and ``metrics``. ``metrics`` maps a metric name to a
      dict holding ``score``.
    - ``threshold``, ``total_cases``, ``passed_cases``, ``failed_cases``,
      ``pass_rate``, ``average_score``, ``model_name`` and
      ``evaluation_timestamp``.
    """

    def __init__(self):
        """Apply the matplotlib/seaborn styling used by the saved plots.

        Note that ``plt.style.use`` and ``sns.set_palette`` change global
        (process-wide) plotting defaults, not just this instance.
        """
        # The seaborn style is named 'seaborn-v0_8' from Matplotlib 3.6 on and
        # 'seaborn' before that; an unknown name would make style.use raise OSError.
        for style in ('seaborn-v0_8', 'seaborn'):
            if style in plt.style.available:
                plt.style.use(style)
                break
        sns.set_palette("husl")

    @staticmethod
    def _normalize_metric_name(name: str) -> str:
        """Lower-case *name* and drop spaces, underscores and hyphens for matching."""
        return "".join(ch for ch in str(name).lower() if ch not in " _-")

    def create_score_histogram(self, results: Dict[str, Any], metric_name: str = "Answer Relevancy") -> go.Figure:
        """
        Create histogram of evaluation scores for one metric.

        A metric entry matches when ``metric_name`` is contained in the metric
        key, comparing case-insensitively and ignoring spaces, underscores and
        hyphens. So "Answer Relevancy" matches keys such as "Answer Relevancy",
        "answer_relevancy" and "AnswerRelevancyMetric".

        Args:
            results: Evaluation results dictionary
            metric_name: Name of the metric to visualize
        Returns:
            Plotly figure object. A placeholder figure is returned when no score
            matches or an error occurs.
        """
        try:
            target = self._normalize_metric_name(metric_name)
            scores = []
            for result in results.get('detailed_results', []):
                metrics = result.get('metrics') or {}
                for metric, data in metrics.items():
                    # Only the requested metric is collected. Previously any
                    # "answerrelevancy" key was always included, which polluted
                    # histograms for other metrics.
                    if target in self._normalize_metric_name(metric):
                        scores.append(data.get('score', 0))

            if not scores:
                logger.warning("No scores found for metric: %s", metric_name)
                return self._create_empty_figure("No data available")

            fig = go.Figure()
            fig.add_trace(go.Histogram(
                x=scores,
                nbinsx=20,
                name=metric_name,
                marker_color='skyblue',
                opacity=0.7,
                hovertemplate='Score Range: %{x}<br>Count: %{y}<extra></extra>'
            ))

            # Dashed vertical line at the pass/fail threshold.
            threshold = results.get('threshold', 0.8)
            fig.add_vline(
                x=threshold,
                line_dash="dash",
                line_color="red",
                annotation_text=f"Threshold: {threshold}",
                annotation_position="top right"
            )

            fig.update_layout(
                title=f'{metric_name} Score Distribution',
                xaxis_title='Score',
                yaxis_title='Frequency',
                showlegend=False,
                template='plotly_white',
                height=400
            )
            return fig
        except Exception as e:
            logger.exception("Error creating score histogram: %s", e)
            return self._create_empty_figure("Error creating histogram")

    def create_pass_fail_pie_chart(self, results: Dict[str, Any]) -> go.Figure:
        """
        Create a donut chart of passed vs. failed case counts.

        Args:
            results: Evaluation results dictionary; reads ``passed_cases`` and
                ``failed_cases`` (default 0).
        Returns:
            Plotly figure. A placeholder figure is returned when both counts are
            0 or an error occurs.
        """
        try:
            passed = results.get('passed_cases', 0)
            failed = results.get('failed_cases', 0)
            if passed == 0 and failed == 0:
                return self._create_empty_figure("No evaluation data available")

            fig = go.Figure(data=[go.Pie(
                labels=['Passed', 'Failed'],
                values=[passed, failed],
                hole=0.3,
                marker_colors=['#2E8B57', '#DC143C'],
                hovertemplate='%{label}: %{value} cases<br>%{percent}<extra></extra>'
            )])
            fig.update_layout(
                title='Pass/Fail Distribution',
                template='plotly_white',
                height=400
            )
            return fig
        except Exception as e:
            logger.exception("Error creating pie chart: %s", e)
            return self._create_empty_figure("Error creating pie chart")

    def create_metrics_comparison(self, results: Dict[str, Any]) -> go.Figure:
        """
        Create one stacked histogram subplot per metric found in the results.

        Args:
            results: Evaluation results dictionary
        Returns:
            Plotly figure with one row per distinct metric name (300px per row).
            A placeholder figure is returned when no metrics exist or an error
            occurs.
        """
        try:
            # Group scores by metric name, preserving first-seen order.
            metrics_data: Dict[str, List[Any]] = {}
            for result in results.get('detailed_results', []):
                metrics = result.get('metrics') or {}
                for metric_name, data in metrics.items():
                    metrics_data.setdefault(metric_name, []).append(data.get('score', 0))

            if not metrics_data:
                return self._create_empty_figure("No metrics data available")

            n_rows = len(metrics_data)
            # make_subplots raises ValueError if vertical_spacing > 1 / (rows - 1).
            # Keep the 0.1 gap for a few metrics and shrink it as rows grow.
            spacing = 0.1 if n_rows == 1 else min(0.1, 0.3 / (n_rows - 1))
            fig = make_subplots(
                rows=n_rows,
                cols=1,
                subplot_titles=list(metrics_data.keys()),
                vertical_spacing=spacing
            )

            colors = px.colors.qualitative.Set3
            for i, (metric_name, scores) in enumerate(metrics_data.items()):
                fig.add_trace(
                    go.Histogram(
                        x=scores,
                        name=metric_name,
                        marker_color=colors[i % len(colors)],
                        opacity=0.7,
                        nbinsx=15
                    ),
                    row=i + 1, col=1
                )

            fig.update_layout(
                title='Metrics Comparison',
                template='plotly_white',
                height=300 * n_rows,
                showlegend=False
            )
            return fig
        except Exception as e:
            logger.exception("Error creating metrics comparison: %s", e)
            return self._create_empty_figure("Error creating metrics comparison")

    def create_score_vs_length_scatter(self, results: Dict[str, Any]) -> go.Figure:
        """
        Create scatter plots of per-case score against input and output length.

        Each case contributes one point to each subplot. The score is that of
        the first entry in the case's ``metrics`` mapping, or 0 when it has
        none. Lengths are character counts of ``input`` and ``actual_output``.

        Args:
            results: Evaluation results dictionary
        Returns:
            Plotly figure with two vertically stacked subplots. A placeholder
            figure is returned when there are no cases or an error occurs.
        """
        try:
            scores = []
            input_lengths = []
            output_lengths = []
            for result in results.get('detailed_results', []):
                # `or ''` / `or {}` also tolerate keys that are present but None.
                input_lengths.append(len(result.get('input') or ''))
                output_lengths.append(len(result.get('actual_output') or ''))
                metrics = result.get('metrics') or {}
                first_metric = next(iter(metrics.values()), {})
                scores.append(first_metric.get('score', 0))

            if not scores:
                return self._create_empty_figure("No data available for scatter plot")

            fig = make_subplots(
                rows=2, cols=1,
                subplot_titles=['Score vs Input Length', 'Score vs Output Length']
            )

            fig.add_trace(
                go.Scatter(
                    x=input_lengths,
                    y=scores,
                    mode='markers',
                    name='Input Length',
                    marker=dict(color='blue', opacity=0.6),
                    hovertemplate='Input Length: %{x}<br>Score: %{y:.3f}<extra></extra>'
                ),
                row=1, col=1
            )
            fig.add_trace(
                go.Scatter(
                    x=output_lengths,
                    y=scores,
                    mode='markers',
                    name='Output Length',
                    marker=dict(color='red', opacity=0.6),
                    hovertemplate='Output Length: %{x}<br>Score: %{y:.3f}<extra></extra>'
                ),
                row=2, col=1
            )

            fig.update_layout(
                title='Score vs Text Length Analysis',
                template='plotly_white',
                # Two stacked subplots need more vertical room than a single plot.
                height=700,
                showlegend=False
            )
            # The grid is 2 rows x 1 column, so the second plot is (row=2, col=1).
            fig.update_xaxes(title_text="Input Length (characters)", row=1, col=1)
            fig.update_xaxes(title_text="Output Length (characters)", row=2, col=1)
            fig.update_yaxes(title_text="Score", row=1, col=1)
            fig.update_yaxes(title_text="Score", row=2, col=1)
            return fig
        except Exception as e:
            logger.exception("Error creating scatter plot: %s", e)
            return self._create_empty_figure("Error creating scatter plot")

    def create_summary_stats_table(self, results: Dict[str, Any]) -> go.Figure:
        """
        Create a two-column (Metric, Value) summary statistics table.

        Args:
            results: Evaluation results dictionary
        Returns:
            Plotly table figure. A placeholder figure is returned if an error
            occurs, e.g. a non-numeric ``pass_rate``.
        """
        try:
            stats = [
                ['Total Cases', results.get('total_cases', 0)],
                ['Passed Cases', results.get('passed_cases', 0)],
                ['Failed Cases', results.get('failed_cases', 0)],
                ['Pass Rate', f"{results.get('pass_rate', 0):.2f}%"],
                ['Average Score', f"{results.get('average_score', 0):.4f}"],
                ['Threshold', results.get('threshold', 0.8)],
                ['Model', results.get('model_name', 'N/A')],
                ['Evaluation Time', results.get('evaluation_timestamp', 'N/A')]
            ]
            # go.Table expects column-major values: [labels, values].
            columns = list(zip(*stats))

            fig = go.Figure(data=[go.Table(
                header=dict(
                    values=['Metric', 'Value'],
                    fill_color='lightblue',
                    align='left',
                    font=dict(size=14, color='black')
                ),
                cells=dict(
                    values=columns,
                    fill_color='white',
                    align='left',
                    font=dict(size=12)
                )
            )])
            fig.update_layout(
                title='Evaluation Summary',
                height=400,
                margin=dict(l=0, r=0, t=30, b=0)
            )
            return fig
        except Exception as e:
            logger.exception("Error creating summary table: %s", e)
            return self._create_empty_figure("Error creating summary table")

    def _create_empty_figure(self, message: str) -> go.Figure:
        """Create a blank 400px-high figure with *message* centered and axes hidden."""
        fig = go.Figure()
        fig.add_annotation(
            text=message,
            xref="paper", yref="paper",
            x=0.5, y=0.5,
            showarrow=False,
            font=dict(size=16)
        )
        fig.update_layout(
            template='plotly_white',
            height=400,
            xaxis=dict(visible=False),
            yaxis=dict(visible=False)
        )
        return fig

    def save_matplotlib_plots(self, results: Dict[str, Any], output_dir: str = "plots") -> List[str]:
        """
        Save a score histogram and a score box plot as PNG files.

        For each case, only the first entry in its ``metrics`` mapping is used.
        Cases without metrics are skipped. Nothing is written when no case has
        a score.

        Args:
            results: Evaluation results dictionary
            output_dir: Directory to write into; created if missing.
        Returns:
            Paths of the files written. An empty list is returned if any
            plotting/saving error occurs.
        Raises:
            OSError: if ``output_dir`` cannot be created.
        """
        import os
        os.makedirs(output_dir, exist_ok=True)
        saved_files = []
        try:
            scores = []
            for result in results.get('detailed_results', []):
                metrics = result.get('metrics') or {}
                first_metric = next(iter(metrics.values()), None)
                if first_metric is not None:
                    scores.append(first_metric.get('score', 0))

            if scores:
                threshold = results.get('threshold', 0.8)

                # Score histogram. The figure is closed in `finally` so a failing
                # savefig does not leak open figures.
                fig, ax = plt.subplots(figsize=(10, 6))
                try:
                    ax.hist(scores, bins=20, alpha=0.7, color='skyblue', edgecolor='black')
                    ax.axvline(threshold, color='red', linestyle='--',
                               label=f"Threshold: {threshold}")
                    ax.set_xlabel('Score')
                    ax.set_ylabel('Frequency')
                    ax.set_title('Score Distribution')
                    ax.legend()
                    ax.grid(True, alpha=0.3)
                    hist_file = os.path.join(output_dir, 'score_histogram.png')
                    fig.savefig(hist_file, dpi=300, bbox_inches='tight')
                finally:
                    plt.close(fig)
                saved_files.append(hist_file)

                # Box plot
                fig, ax = plt.subplots(figsize=(8, 6))
                try:
                    ax.boxplot(scores)
                    # Label the single box (drawn at x=1) via its tick. Boxplot's
                    # `labels=` was renamed `tick_labels` in Matplotlib 3.9, so
                    # neither keyword is used here.
                    ax.set_xticks([1])
                    ax.set_xticklabels(['Scores'])
                    ax.set_ylabel('Score')
                    ax.set_title('Score Distribution (Box Plot)')
                    ax.grid(True, alpha=0.3)
                    box_file = os.path.join(output_dir, 'score_boxplot.png')
                    fig.savefig(box_file, dpi=300, bbox_inches='tight')
                finally:
                    plt.close(fig)
                saved_files.append(box_file)

            logger.info("Saved %d matplotlib plots to %s", len(saved_files), output_dir)
            return saved_files
        except Exception as e:
            logger.exception("Error saving matplotlib plots: %s", e)
            return []