Spaces:
Running
Running
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import pandas as pd | |
| from typing import Optional, List | |
| from .data_loader import DataLoader | |
| from .utils import get_metric_choices, clean_metric_names | |
class RadarPlotter:
    """Render a radar (polar) chart comparing models across individual metrics.

    Each model (one DataFrame row) is drawn as a closed, lightly filled
    polygon; each metric column is one angular axis.
    """

    def __init__(self, data_loader: DataLoader):
        """Keep a reference to the data source and resolve the metric axes.

        Args:
            data_loader: Provides ``df_all``, the fallback DataFrame used when
                ``create_radar_chart`` is called without (non-empty) data.
        """
        self.data_loader = data_loader
        # Every individual metric except the aggregate "Average ⭐" choice.
        # NOTE(review): clean_metric_names presumably strips display markers
        # so the names match DataFrame columns — confirm in .utils.
        all_metrics_with_markers = get_metric_choices()
        self.metrics = clean_metric_names(
            [m for m in all_metrics_with_markers if m != "Average ⭐"]
        )

    @staticmethod
    def _message_figure(message: str) -> plt.Figure:
        """Return an empty polar figure that only displays ``message``."""
        fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(projection="polar"))
        ax.text(0.5, 0.5, message, ha="center", va="center", transform=ax.transAxes)
        return fig

    def create_radar_chart(self, df: Optional[pd.DataFrame] = None,
                           models: Optional[List[str]] = None,
                           *, max_models: Optional[int] = 6) -> plt.Figure:
        """Build the radar chart figure.

        Args:
            df: Rows to plot, one per model. If ``None`` or empty, falls back
                to ``self.data_loader.df_all`` (or an empty frame if that is
                ``None``).
            models: Optional model names to keep, matched against the "Model"
                column. ``None`` or an empty list keeps every row.
            max_models: Maximum number of polygons to draw (must be >= 1).
                When exceeded, the rows with the largest "Average" are kept,
                or the first rows if there is no "Average" column. ``None``
                disables the limit.

        Returns:
            A matplotlib Figure with a single polar Axes. When there is nothing
            to plot, the figure only contains an explanatory message.

        Raises:
            ValueError: If ``max_models`` is not ``None`` and is less than 1.
        """
        if max_models is not None and max_models < 1:
            raise ValueError(f"max_models must be >= 1 or None, got {max_models}")

        if df is None or df.empty:
            df = self.data_loader.df_all.copy() if self.data_loader.df_all is not None else pd.DataFrame()

        # Honour an explicit model selection (this argument used to be ignored).
        if models and "Model" in df.columns:
            df = df[df["Model"].isin(models)]

        if df.empty:
            return self._message_figure("No data available")

        # Cap the number of polygons so the chart stays legible.
        if max_models is not None and len(df) > max_models:
            if "Average" in df.columns:
                df = df.nlargest(max_models, "Average")
            else:
                df = df.head(max_models)

        # Metrics (axes) that actually exist in this frame.
        valid_metrics = [m for m in self.metrics if m in df.columns]
        if not valid_metrics:
            return self._message_figure("No valid metrics")

        # Convert to floats; missing or non-numeric scores are drawn as 0 so
        # every polygon remains a closed shape instead of breaking at NaN.
        scores = (
            df[valid_metrics]
            .apply(pd.to_numeric, errors="coerce")
            .fillna(0.0)
            .to_numpy(dtype=float)
        )
        if "Model" in df.columns:
            labels = df["Model"].tolist()
        else:
            labels = [str(i) for i in df.index]

        # One evenly spaced angle per metric; the first angle is repeated at
        # the end so each outline closes back on its starting point.
        angles = np.linspace(0, 2 * np.pi, len(valid_metrics), endpoint=False).tolist()
        angles += angles[:1]

        fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(projection="polar"))

        colors = plt.cm.tab10(np.linspace(0, 1, len(df)))
        for idx, (label, row_scores) in enumerate(zip(labels, scores)):
            values = row_scores.tolist()
            values += values[:1]
            ax.plot(angles, values, "o-", linewidth=2, label=label, color=colors[idx])
            ax.fill(angles, values, alpha=0.1, color=colors[idx])

        ax.set_xticks(angles[:-1])
        ax.set_xticklabels(valid_metrics, fontsize=8)
        # The default radial scale is [0, 1]; widen it instead of clipping the
        # polygons if any score exceeds 1.
        ax.set_ylim(0, max(1.0, float(scores.max())))
        ax.set_title(f"Performance Radar ({len(valid_metrics)} metrics)",
                     fontsize=12, fontweight="bold", pad=20)
        ax.legend(loc="upper right", bbox_to_anchor=(1.3, 1.0), fontsize=8)
        ax.grid(True, linestyle="--", alpha=0.5)
        fig.tight_layout()
        return fig