iWorldBench commited on
Commit
4097ba4
·
0 Parent(s):

Initial commit: iWorld-Bench leaderboard with full code and data

Browse files
.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ bench/
2
+ __pycache__/
3
+ *.pyc
4
+ *.pyo
5
+ *.pyd
6
+ .vscode/
7
+ .idea/
8
+ .DS_Store
app.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """iWorld-Bench Leaderboard - Hugging Face Space"""
2
+
3
+ from typing import Optional, List
4
+ import gradio as gr
5
+ import pandas as pd
6
+ import matplotlib.pyplot as plt
7
+
8
+ from src.data_loader import DataLoader
9
+ from src.leaderboard import Leaderboard
10
+ from src.plotter import Plotter
11
+ from src.radar_plotter import RadarPlotter
12
+ from src.styling import dataframe_to_html, get_academic_css
13
+ from src.utils import get_metric_choices, clean_metric_names
14
+
15
# Value pre-selected in the ranking-metric dropdown and the plot-metric radio.
DEFAULT_METRIC = "Average ⭐"

# Application-wide singletons, created once at import time.  The three views
# are all constructed around the same DataLoader instance.
data_loader = DataLoader(results_dir="./data")
leaderboard = Leaderboard(data_loader)
plotter = Plotter(data_loader)
radar_plotter = RadarPlotter(data_loader)
22
+
23
def reload_data():
    """Re-read the results file and rebuild what the page shows on load.

    Returns:
        A 6-tuple in the order the callback outputs are wired: status
        message, updates for the open-source / year / category dropdowns,
        leaderboard HTML, and the radar-chart figure.
    """
    status = data_loader.reload_data()
    frame = data_loader.df_all

    if frame is None or frame.empty:
        # Nothing to rank: put the loader's message in an axis-less figure
        # and collapse every filter to its single "All" entry.
        message_fig, message_ax = plt.subplots(figsize=(6, 3))
        message_ax.text(0.5, 0.5, status, ha="center", va="center")
        message_ax.axis("off")
        return (
            status,
            gr.update(choices=["All"], value="All"),
            gr.update(choices=["All"], value="All"),
            gr.update(choices=["All"], value="All"),
            "<div class='placeholder'>No data available</div>",
            message_fig,
        )

    dropdown_updates = [
        gr.update(choices=options, value="All")
        for options in (
            data_loader.get_open_source_choices(),
            data_loader.get_year_choices(),
            data_loader.get_category_choices(),
        )
    ]

    # Every metric except the overall average is requested as a column.
    extra_metrics = clean_metric_names(
        [choice for choice in get_metric_choices() if choice != "Average ⭐"]
    )
    ranking = leaderboard.update_leaderboard(
        metric="Average",
        top_k=25,
        model_filter="",
        open_source_filter="All",
        year_filter="All",
        category_filter="All",
        sort_mode="Auto",
        selected_metrics=extra_metrics,
    )

    radar_fig = radar_plotter.create_radar_chart()
    table_html = dataframe_to_html(ranking)

    return (status, *dropdown_updates, table_html, radar_fig)
60
+
61
def update_leaderboard_wrapper(metric, top_k, model_filter, open_source_filter,
                               year_filter, category_filter, sort_mode, selected_metrics):
    """Callback for the update button: rebuild the table and the radar chart.

    Metric labels arrive with their display markers and are cleaned before
    being passed to the leaderboard.

    Returns:
        ``(leaderboard_html, radar_figure)``.
    """
    primary_metric = clean_metric_names([metric])[0]
    extra_metrics = clean_metric_names(selected_metrics)

    table_df = leaderboard.update_leaderboard(
        primary_metric,
        top_k,
        model_filter,
        open_source_filter,
        year_filter,
        category_filter,
        sort_mode,
        extra_metrics,
    )

    # The radar chart is given the full rows of the models currently listed.
    all_rows = data_loader.df_all
    if table_df.empty or all_rows is None:
        radar_df = pd.DataFrame()
    else:
        shown_models = table_df["Model"].tolist()
        radar_df = all_rows[all_rows["Model"].isin(shown_models)].copy()

    return dataframe_to_html(table_df), radar_plotter.create_radar_chart(radar_df)
80
+
81
def create_comparison_plot_wrapper(model_filter, open_source_filter, year_filter,
                                   category_filter, selected_plot_metric, plot_sort_mode):
    """Callback for the plot button: draw the comparison plot for one metric.

    The metric label is stripped of its display marker first; all other
    arguments are forwarded to the plotter unchanged.
    """
    (plot_metric,) = clean_metric_names([selected_plot_metric])
    return plotter.create_comparison_plot(
        model_filter,
        open_source_filter,
        year_filter,
        category_filter,
        plot_metric,
        plot_sort_mode,
    )
87
+
88
# The stylesheet is built once here and handed to gr.Blocks at construction.
academic_css = get_academic_css()


def _all_only_dropdown(label):
    """Create a filter dropdown that initially offers only "All"."""
    return gr.Dropdown(
        label=label,
        choices=["All"],
        value="All",
        interactive=True,
    )


with gr.Blocks(css=academic_css) as demo:
    # ---- Header -----------------------------------------------------------
    gr.Markdown(
        """
        # <span class="emoji">🌍</span> iWorld-Bench Leaderboard
        <span class="subtitle">A Benchmark for Interactive World Models with a Unified Action Generation Framework</span>

        **[πŸ“„ Paper](https://arxiv.org/abs/xxx) | [πŸ’» Code](https://github.com/xxx/iworld-bench) | [🌐 Website](https://xxx.github.io/iworld-bench)**
        """,
        elem_id="title"
    )

    status_box = gr.Markdown("Loading results...", elem_id="status")

    # ---- Ranking controls -------------------------------------------------
    metric_choices = get_metric_choices()
    additional_metric_choices = [m for m in metric_choices if m != "Average ⭐"]

    with gr.Row():
        with gr.Column(scale=2):
            metric_dropdown = gr.Dropdown(
                label="Primary Ranking Metric",
                choices=metric_choices,
                value=DEFAULT_METRIC,
                interactive=True,
            )
        with gr.Column(scale=1):
            sort_mode_radio = gr.Radio(
                label="Sort Order",
                choices=["Auto", "Ascending (low β†’ high)", "Descending (high β†’ low)"],
                value="Auto",
                interactive=True,
            )
            topk_slider = gr.Slider(
                label="Display Top-K Models",
                minimum=3,
                maximum=50,
                value=25,
                step=1,
                interactive=True,
            )

    with gr.Row():
        # All additional metrics start ticked.
        metrics_select = gr.CheckboxGroup(
            label="Additional Metrics to Display (πŸ“Š indicates dimension metrics)",
            choices=additional_metric_choices,
            value=list(additional_metric_choices),
            interactive=True,
        )

    # ---- Filters ----------------------------------------------------------
    with gr.Row():
        with gr.Column(scale=1):
            model_filter_box = gr.Textbox(
                label="Filter by Model Name",
                placeholder="Enter model name (partial match)",
                interactive=True,
            )
        with gr.Column(scale=1):
            open_source_dropdown = _all_only_dropdown("Filter by Open Source")
        with gr.Column(scale=1):
            year_dropdown = _all_only_dropdown("Filter by Year")
        with gr.Column(scale=1):
            category_dropdown = _all_only_dropdown("Filter by Category")

    with gr.Row():
        reload_button = gr.Button("πŸ”„ Reload Data", variant="secondary", size="sm")
        update_button = gr.Button("βœ… Update Leaderboard", variant="primary", size="sm")

    # ---- Results ----------------------------------------------------------
    leaderboard_html = gr.HTML(
        label="Leaderboard Table",
        value="<div class='placeholder'>Leaderboard will be displayed here...</div>"
    )

    with gr.Row():
        radar_plot = gr.Plot(label="Dimension Radar Chart", format="png")

    # ---- Comparison plot --------------------------------------------------
    with gr.Row():
        with gr.Column(scale=2):
            plot_metric_radio = gr.Radio(
                label="Select Metric for Comparison Plot",
                choices=metric_choices,
                value=DEFAULT_METRIC,
                interactive=True,
            )
        with gr.Column(scale=1):
            plot_sort_radio = gr.Radio(
                label="Plot Sort Order",
                choices=["Ascending (low β†’ high)", "Descending (high β†’ low)"],
                value="Descending (high β†’ low)",
                interactive=True,
            )
            plot_update_button = gr.Button("πŸ“Š Generate Comparison Plot", variant="primary", size="sm")

    comparison_plot = gr.Plot(label="Model Comparison Visualization", format="png")

    # ---- Event wiring -----------------------------------------------------
    # The four filter widgets feed both the table and the comparison plot;
    # the refresh outputs are shared by the reload button and the page load.
    filter_inputs = [
        model_filter_box, open_source_dropdown, year_dropdown, category_dropdown,
    ]
    refresh_outputs = [
        status_box, open_source_dropdown, year_dropdown, category_dropdown,
        leaderboard_html, radar_plot,
    ]

    reload_button.click(
        fn=reload_data,
        inputs=[],
        outputs=list(refresh_outputs),
    )

    update_button.click(
        fn=update_leaderboard_wrapper,
        inputs=[metric_dropdown, topk_slider, *filter_inputs, sort_mode_radio, metrics_select],
        outputs=[leaderboard_html, radar_plot],
    )

    plot_update_button.click(
        fn=create_comparison_plot_wrapper,
        inputs=[*filter_inputs, plot_metric_radio, plot_sort_radio],
        outputs=[comparison_plot],
    )

    # Populate the page as soon as it is opened.
    demo.load(
        fn=reload_data,
        inputs=[],
        outputs=list(refresh_outputs),
    )

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )
data/results.csv ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Model,Category,Average,Image Quality,Brightness Consistency,Color Temperature,Sharpness Retention,Motion Smoothness,Trajectory Accuracy,Memory Symmetry,Trajectory Alignment,Year
2
+ NVIDIA Cosmos,Text-Conditioned,0.6275,0.6778,0.6952,0.7170,0.4363,0.9907,0.4955,0.3738,0.6419,2024
3
+ HunyuanVideo-1.5,Text-Conditioned,0.7188,0.7128,0.7027,0.7477,0.5545,0.9908,0.6844,0.6336,0.6449,2024
4
+ WAN 2.2,Text-Conditioned,0.5731,0.5545,0.3886,0.3411,0.3428,0.9557,0.6514,0.4480,0.5703,2024
5
+ CogVideoX-I2V,Text-Conditioned,0.6963,0.6521,0.8988,0.8129,0.7951,0.9938,0.5950,0.6010,0.4084,2024
6
+ YUME 1.5,Text-Conditioned,0.6209,0.6232,0.3810,0.4165,0.4023,0.9765,0.7113,0.5276,0.5988,2024
7
+ Matrix-game 2.0,One-hot,0.5663,0.4851,0.2963,0.2937,0.4149,0.9848,0.7008,0.3311,0.6362,2024
8
+ HY-World 1.5,One-hot,0.7873,0.6675,0.8051,0.7819,0.6634,0.9921,0.7472,0.8481,0.6776,2024
9
+ CameraCtrl,Intrinsics/Extrinsics,0.5762,0.4473,0.3717,0.2511,0.4545,0.9796,0.6778,0.4279,0.6097,2024
10
+ MotionCtrl,Intrinsics/Extrinsics,0.5486,0.4562,0.3980,0.2012,0.4294,0.9735,0.6730,0.3098,0.5932,2024
11
+ CamI2V,Intrinsics/Extrinsics,0.5765,0.5284,0.4343,0.3568,0.4297,0.9861,0.6314,0.3631,0.6038,2024
12
+ RealCam-I2V,Intrinsics/Extrinsics,0.6865,0.6227,0.4130,0.5547,0.6269,0.9860,0.5630,0.7948,0.6668,2024
13
+ videox-fun-Wan,Intrinsics/Extrinsics,0.7474,0.6410,0.5972,0.5473,0.5998,0.9858,0.7172,0.9009,0.6876,2024
14
+ AC3D,Intrinsics/Extrinsics,0.7149,0.4573,0.7307,0.6524,0.5332,0.9919,0.5785,0.9068,0.6250,2024
15
+ ASTRA,Intrinsics/Extrinsics,0.5980,0.5335,0.5091,0.4338,0.5488,0.9799,0.6115,0.4323,0.5518,2024
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ huggingface-hub>=0.20.0
3
+ pandas>=2.0.0
4
+ matplotlib>=3.7.0
5
+ numpy>=1.24.0
6
+ plotly>=5.0.0
src/__init__.py ADDED
File without changes
src/data_loader.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import os
3
+ from typing import Optional, List
4
+
5
class DataLoader:
    """Load the leaderboard results CSV and answer filter queries against it.

    The parsed table is kept in ``df_all``; it is ``None`` whenever the most
    recent load failed.
    """

    def __init__(self, results_dir: str = "./data"):
        """Remember *results_dir* and perform an initial load."""
        self.results_dir = results_dir
        self.df_all: Optional[pd.DataFrame] = None
        self.reload_data()

    def reload_data(self) -> str:
        """(Re)read ``results.csv`` from ``results_dir`` into ``df_all``.

        A sample file is written first when none exists.

        Returns:
            A status line for display.  Failures (including a failure to
            write the sample file) are reported in the message instead of
            being raised, and leave ``df_all`` set to ``None``.
        """
        csv_path = os.path.join(self.results_dir, "results.csv")
        try:
            if not os.path.exists(csv_path):
                self._create_sample_data(csv_path)
            self.df_all = pd.read_csv(csv_path)
            return f"βœ… Loaded {len(self.df_all)} models from {csv_path}"
        except Exception as e:
            # Boundary handler: the UI shows the message, so nothing is
            # swallowed silently.
            self.df_all = None
            return f"❌ Error loading data: {str(e)}"

    def _create_sample_data(self, path: str):
        """Write the built-in 14-model sample table to *path* as CSV."""
        parent = os.path.dirname(path)
        if parent:
            # dirname is "" for a bare file name, and makedirs("") raises.
            os.makedirs(parent, exist_ok=True)
        data = {
            "Model": [
                "NVIDIA Cosmos", "HunyuanVideo-1.5", "WAN 2.2", "CogVideoX-I2V", "YUME 1.5",
                "Matrix-game 2.0", "HY-World 1.5",
                "CameraCtrl", "MotionCtrl", "CamI2V", "RealCam-I2V", "videox-fun-Wan", "AC3D", "ASTRA"
            ],
            "Category": [
                "Text-Conditioned", "Text-Conditioned", "Text-Conditioned", "Text-Conditioned", "Text-Conditioned",
                "One-hot", "One-hot",
                "Intrinsics/Extrinsics", "Intrinsics/Extrinsics", "Intrinsics/Extrinsics",
                "Intrinsics/Extrinsics", "Intrinsics/Extrinsics", "Intrinsics/Extrinsics", "Intrinsics/Extrinsics"
            ],
            "Average": [0.6275, 0.7188, 0.5731, 0.6963, 0.6209, 0.5663, 0.7873,
                        0.5762, 0.5486, 0.5765, 0.6865, 0.7474, 0.7149, 0.5980],
            "Image Quality": [0.6778, 0.7128, 0.5545, 0.6521, 0.6232, 0.4851, 0.6675,
                              0.4473, 0.4562, 0.5284, 0.6227, 0.6410, 0.4573, 0.5335],
            "Brightness Consistency": [0.6952, 0.7027, 0.3886, 0.8988, 0.3810, 0.2963, 0.8051,
                                       0.3717, 0.3980, 0.4343, 0.4130, 0.5972, 0.7307, 0.5091],
            "Color Temperature": [0.7170, 0.7477, 0.3411, 0.8129, 0.4165, 0.2937, 0.7819,
                                  0.2511, 0.2012, 0.3568, 0.5547, 0.5473, 0.6524, 0.4338],
            "Sharpness Retention": [0.4363, 0.5545, 0.3428, 0.7951, 0.4023, 0.4149, 0.6634,
                                    0.4545, 0.4294, 0.4297, 0.6269, 0.5998, 0.5332, 0.5488],
            "Motion Smoothness": [0.9907, 0.9908, 0.9557, 0.9938, 0.9765, 0.9848, 0.9921,
                                  0.9796, 0.9735, 0.9861, 0.9860, 0.9858, 0.9919, 0.9799],
            "Trajectory Accuracy": [0.4955, 0.6844, 0.6514, 0.5950, 0.7113, 0.7008, 0.7472,
                                    0.6778, 0.6730, 0.6314, 0.5630, 0.7172, 0.5785, 0.6115],
            "Memory Symmetry": [0.3738, 0.6336, 0.4480, 0.6010, 0.5276, 0.3311, 0.8481,
                                0.4279, 0.3098, 0.3631, 0.7948, 0.9009, 0.9068, 0.4323],
            "Trajectory Alignment": [0.6419, 0.6449, 0.5703, 0.4084, 0.5988, 0.6362, 0.6776,
                                     0.6097, 0.5932, 0.6038, 0.6668, 0.6876, 0.6250, 0.5518],
            "Year": [2024] * 14
        }
        df = pd.DataFrame(data)
        df.to_csv(path, index=False)
        print(f"Created sample data at {path}")

    def _column_choices(self, column: str, reverse: bool = False) -> list:
        """Return ``["All"]`` followed by the sorted distinct values of *column*.

        Only ``["All"]`` is returned when no data is loaded or the column is
        absent.  Values keep the type pandas parsed them as.
        """
        if self.df_all is None or column not in self.df_all.columns:
            return ["All"]
        distinct = self.df_all[column].dropna().unique().tolist()
        return ["All"] + sorted(distinct, reverse=reverse)

    def get_open_source_choices(self) -> List[str]:
        """Dropdown choices for the optional "Open Source" column."""
        return self._column_choices("Open Source")

    def get_year_choices(self) -> list:
        """Dropdown choices for the "Year" column, newest first."""
        return self._column_choices("Year", reverse=True)

    def get_category_choices(self) -> List[str]:
        """Dropdown choices for the "Category" column."""
        return self._column_choices("Category")

    @staticmethod
    def _is_set(choice) -> bool:
        """Tell whether a dropdown value actually narrows the result."""
        return choice not in (None, "", "All")

    def filter_data(self, model_filter: str = "", open_source_filter: str = "All",
                    year_filter: str = "All", category_filter: str = "All") -> pd.DataFrame:
        """Return the rows of ``df_all`` that satisfy every active filter.

        Args:
            model_filter: Case-insensitive literal substring of the model
                name; blank means no filtering.
            open_source_filter: Exact "Open Source" value, or "All".
            year_filter: Year as an int or numeric string, or "All".
            category_filter: Exact "Category" value, or "All".

        Returns:
            A new DataFrame.  It is empty when no data is loaded or when
            *year_filter* cannot be read as an integer.  A filter whose
            column is missing from the data is skipped.
        """
        if self.df_all is None:
            return pd.DataFrame()
        df = self.df_all.copy()

        needle = (model_filter or "").strip()
        if needle and "Model" in df.columns:
            # regex=False: the text is typed by a user, so "(" or "+" must be
            # matched literally instead of being compiled as a pattern.
            df = df[df["Model"].str.contains(needle, case=False, na=False, regex=False)]

        if self._is_set(open_source_filter) and "Open Source" in df.columns:
            df = df[df["Open Source"] == open_source_filter]

        if self._is_set(year_filter) and "Year" in df.columns:
            try:
                year = int(year_filter)
            except (TypeError, ValueError):
                # An unreadable year cannot match any row.
                return df.iloc[0:0]
            df = df[df["Year"] == year]

        if self._is_set(category_filter) and "Category" in df.columns:
            df = df[df["Category"] == category_filter]
        return df
src/leaderboard.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from typing import List, Optional
3
+ from .data_loader import DataLoader
4
+
5
class Leaderboard:
    """Build the ranked, display-ready leaderboard table."""

    def __init__(self, data_loader: "DataLoader"):
        """Keep the loader that is queried on every update."""
        self.data_loader = data_loader

    def update_leaderboard(self, metric: str = "Average", top_k: int = 25,
                           model_filter: str = "", open_source_filter: str = "All",
                           year_filter: str = "All", category_filter: str = "All",
                           sort_mode: str = "Auto",
                           selected_metrics: Optional[List[str]] = None) -> pd.DataFrame:
        """Filter, sort, truncate and format the results for display.

        Args:
            metric: Column to rank by.  Rows are left in file order when the
                column does not exist.
            top_k: Maximum number of rows kept after sorting.
            model_filter, open_source_filter, year_filter, category_filter:
                Forwarded unchanged to ``data_loader.filter_data``.
            sort_mode: A label starting with "Ascending" sorts low to high;
                anything else, including "Auto", sorts high to low.
            selected_metrics: Further metric columns to show; defaults to
                ``["Average"]``.

        Returns:
            A DataFrame with Rank / Model / Category (those that exist), the
            ranking metric, the selected metrics, and any Paper / Code link
            columns.  Metric cells are 4-decimal strings ("-" when missing)
            and link cells are HTML anchors.  The frame is empty when no row
            passes the filters.
        """
        df = self.data_loader.filter_data(model_filter, open_source_filter,
                                          year_filter, category_filter)
        if df.empty:
            return pd.DataFrame()

        # Prefix match, like Plotter, so the decision does not hinge on the
        # exact arrow glyph in the label.
        ascending = str(sort_mode).startswith("Ascending")
        if metric in df.columns:
            df = df.sort_values(by=metric, ascending=ascending)

        # The slider value may arrive as a float; head() needs an int.
        df = df.head(int(top_k)).reset_index(drop=True)
        df.insert(0, "Rank", range(1, len(df) + 1))

        if selected_metrics is None:
            selected_metrics = ["Average"]

        display_cols = [c for c in ("Rank", "Model", "Category") if c in df.columns]
        # The ranking metric is always shown, ahead of the additional ones,
        # so the table never hides the values it is sorted by.
        for name in [metric, *selected_metrics]:
            if name in df.columns and name not in display_cols:
                display_cols.append(name)
        # Optional link columns go last.
        for name in ("Paper", "Code"):
            if name in df.columns and name not in display_cols:
                display_cols.append(name)

        result_df = df[display_cols].copy()

        # Format numeric values
        non_metric_cols = ["Rank", "Model", "Category", "Paper", "Code", "Open Source", "Year"]
        for col in result_df.columns:
            if col not in non_metric_cols:
                result_df[col] = result_df[col].apply(
                    lambda x: f"{x:.4f}" if pd.notna(x) and isinstance(x, (int, float)) else "-"
                )

        # Create hyperlinks if columns exist
        if "Paper" in result_df.columns:
            result_df["Paper"] = result_df["Paper"].apply(
                lambda x: f'<a href="{x}" target="_blank">πŸ“„</a>' if pd.notna(x) and x != "-" else "-"
            )
        if "Code" in result_df.columns:
            result_df["Code"] = result_df["Code"].apply(
                lambda x: f'<a href="{x}" target="_blank">πŸ’»</a>' if pd.notna(x) and x != "-" else "-"
            )

        return result_df
src/plotter.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import pandas as pd
3
+ from typing import Optional
4
+ from .data_loader import DataLoader
5
+
6
class Plotter:
    """Draw the horizontal-bar comparison of models on a single metric."""

    def __init__(self, data_loader: "DataLoader"):
        """Keep the loader that is queried on every plot request."""
        self.data_loader = data_loader

    def create_comparison_plot(self, model_filter: str = "",
                               open_source_filter: str = "All",
                               year_filter: str = "All",
                               category_filter: str = "All",
                               metric: str = "Average",
                               sort_mode: str = "Descending (high β†’ low)") -> plt.Figure:
        """Plot *metric* for the filtered models as horizontal bars.

        Args:
            model_filter, open_source_filter, year_filter, category_filter:
                Forwarded unchanged to ``data_loader.filter_data``.
            metric: Column to plot.
            sort_mode: A label starting with "Ascending" sorts low to high;
                anything else sorts high to low.

        Returns:
            A matplotlib figure with at most 20 bars, coloured by category.
            A "No data available" placeholder figure is returned when no row
            passes the filters, the metric column is missing, or it holds no
            numeric value.
        """
        df = self.data_loader.filter_data(model_filter, open_source_filter,
                                          year_filter, category_filter)
        if not df.empty and metric in df.columns:
            # Rows without a usable number cannot be drawn: NaN would reach
            # set_xlim and the value labels.
            df = df.copy()
            df[metric] = pd.to_numeric(df[metric], errors="coerce")
            df = df.dropna(subset=[metric])

        if df.empty or metric not in df.columns:
            fig, ax = plt.subplots(figsize=(10, 6))
            ax.text(0.5, 0.5, "No data available", ha="center", va="center", fontsize=14)
            ax.axis("off")
            return fig

        ascending = sort_mode.startswith("Ascending")
        df = df.sort_values(by=metric, ascending=ascending)
        if len(df) > 20:
            df = df.head(20)

        fig, ax = plt.subplots(figsize=(12, max(6, len(df) * 0.4)))
        colors = {
            "Text-Conditioned": "#3b82f6",
            "One-hot": "#10b981",
            "Intrinsics/Extrinsics": "#f59e0b"
        }
        # Unknown categories, or a file without the column, fall back to grey.
        categories = df["Category"].tolist() if "Category" in df.columns else [None] * len(df)
        bar_colors = [colors.get(cat, "#6b7280") for cat in categories]

        bars = ax.barh(df["Model"], df[metric], color=bar_colors, edgecolor="white", linewidth=0.5)
        for bar, val in zip(bars, df[metric]):
            width = bar.get_width()
            ax.text(width + 0.01, bar.get_y() + bar.get_height()/2,
                    f"{val:.4f}", ha="left", va="center", fontsize=9)

        ax.set_xlabel(metric, fontsize=12, fontweight="bold")
        ax.set_title(f"Model Comparison - {metric}", fontsize=14, fontweight="bold", pad=20)
        upper = df[metric].max()
        if upper > 0:
            # Leave 15% head-room for the value labels; otherwise keep the
            # automatic limits rather than forcing a degenerate range.
            ax.set_xlim(0, upper * 1.15)
        ax.grid(axis="x", alpha=0.3, linestyle="--")

        from matplotlib.patches import Patch
        present = set(categories)
        legend_elements = [Patch(facecolor=color, label=cat)
                           for cat, color in colors.items()
                           if cat in present]
        if legend_elements:
            ax.legend(handles=legend_elements, loc="lower right", title="Category")
        plt.tight_layout()
        return fig
src/radar_plotter.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ import pandas as pd
4
+ from typing import Optional, List
5
+ from .data_loader import DataLoader
6
+ from .utils import get_dimension_metrics
7
+
8
class RadarPlotter:
    """Draw a radar chart of per-dimension scores for a handful of models."""

    def __init__(self, data_loader: "DataLoader"):
        """Keep the loader and fetch the dimension-to-metrics mapping."""
        self.data_loader = data_loader
        self.dimension_metrics = get_dimension_metrics()

    def create_radar_chart(self, df: Optional[pd.DataFrame] = None,
                           models: Optional[List[str]] = None) -> plt.Figure:
        """Plot one polygon per model over the configured dimensions.

        Args:
            df: Rows to plot.  When ``None`` or empty, all loaded data is
                used instead.
            models: Optional model names to restrict the chart to.

        Returns:
            A polar matplotlib figure.  At most 8 models are drawn: the
            highest "Average" ones, or the first 8 rows when that column is
            absent.  A "No data available" placeholder is returned when
            nothing is left to plot.
        """
        if df is None or df.empty:
            df = self.data_loader.df_all if self.data_loader.df_all is not None else pd.DataFrame()
        # Dimension columns are added below, so never touch the caller's
        # frame or the loader's table.
        df = df.copy()

        if models is not None and "Model" in df.columns:
            df = df[df["Model"].isin(models)].copy()

        if df.empty:
            fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(projection='polar'))
            ax.text(0.5, 0.5, "No data available", ha="center", va="center", transform=ax.transAxes)
            return fig

        if len(df) > 8:
            df = df.nlargest(8, "Average") if "Average" in df.columns else df.head(8)

        # Each dimension is the row-wise mean of whichever of its metrics
        # are present; a dimension with none of them is drawn as 0.
        dimensions = list(self.dimension_metrics.keys())
        for dim_name, metrics in self.dimension_metrics.items():
            valid_metrics = [m for m in metrics if m in df.columns]
            if valid_metrics:
                df[dim_name] = df[valid_metrics].mean(axis=1)
            else:
                df[dim_name] = 0

        # Repeat the first angle/value at the end to close each polygon.
        angles = np.linspace(0, 2 * np.pi, len(dimensions), endpoint=False).tolist()
        angles += angles[:1]

        fig, ax = plt.subplots(figsize=(10, 10), subplot_kw=dict(projection='polar'))
        colors = plt.cm.tab10(np.linspace(0, 1, len(df)))

        for idx, (_, row) in enumerate(df.iterrows()):
            values = [row.get(dim, 0) for dim in dimensions]
            values += values[:1]
            ax.plot(angles, values, 'o-', linewidth=2, label=row["Model"], color=colors[idx])
            ax.fill(angles, values, alpha=0.1, color=colors[idx])

        ax.set_xticks(angles[:-1])
        ax.set_xticklabels(dimensions, fontsize=11)
        ax.set_ylim(0, 1)
        ax.set_title("Dimension Performance Radar", fontsize=14, fontweight="bold", pad=20)
        ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.0), fontsize=9)
        ax.grid(True, linestyle='--', alpha=0.5)
        plt.tight_layout()
        return fig
src/styling.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+
3
def get_academic_css() -> str:
    """Return the stylesheet for the leaderboard page.

    Covers the page title, the leaderboard table (sticky header, zebra rows,
    top-three rank highlights, category tags, link icons), the primary
    button, status colours, and the ``placeholder`` message box shown when
    there is nothing to display.
    """
    return """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
:root {
    --primary: #2563eb;
    --primary-dark: #1d4ed8;
    --accent: #06b6d4;
    --text-dark: #1a1a1a;
    --text-gray: #4a4a4a;
    --border: #e5e5e5;
    --bg-light: #f8f9fa;
    --success: #10b981;
}
body {
    font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
}
.gradio-container {
    max-width: 1400px !important;
}
h1 {
    color: var(--text-dark) !important;
    font-weight: 700 !important;
}
.subtitle {
    color: var(--text-gray);
    font-size: 1.1rem;
    margin-top: 0.5rem;
}
.emoji {
    font-size: 1.5em;
}
.placeholder {
    padding: 1.5rem;
    text-align: center;
    color: var(--text-gray);
    background: var(--bg-light);
    border: 1px dashed var(--border);
    border-radius: 8px;
}
.leaderboard-table {
    width: 100%;
    border-collapse: collapse;
    font-size: 0.9rem;
    margin-top: 1rem;
}
.leaderboard-table th {
    background: linear-gradient(135deg, #64748b 0%, #94a3b8 100%);
    color: white;
    padding: 12px 8px;
    text-align: center;
    font-weight: 600;
    position: sticky;
    top: 0;
}
.leaderboard-table td {
    padding: 10px 8px;
    text-align: center;
    border-bottom: 1px solid var(--border);
}
.leaderboard-table tr:nth-child(even) {
    background-color: #f8fafc;
}
.leaderboard-table tr:hover {
    background-color: #f1f5f9;
}
.rank-1 { background: linear-gradient(135deg, #ffd700 0%, #ffed4a 100%) !important; font-weight: bold; }
.rank-2 { background: linear-gradient(135deg, #c0c0c0 0%, #e5e7eb 100%) !important; font-weight: bold; }
.rank-3 { background: linear-gradient(135deg, #cd7f32 0%, #fdba74 100%) !important; font-weight: bold; }
.best-score {
    color: #2c7a7b;
    font-weight: 700;
    background: #e6fffa;
}
.category-tag {
    display: inline-block;
    padding: 2px 8px;
    border-radius: 12px;
    font-size: 0.8rem;
    font-weight: 500;
}
.cat-text { background: #dbeafe; color: #1e40af; }
.cat-onehot { background: #d1fae5; color: #065f46; }
.cat-camera { background: #fef3c7; color: #92400e; }
.leaderboard-table a {
    color: var(--primary);
    text-decoration: none;
    font-size: 1.2rem;
}
.leaderboard-table a:hover {
    opacity: 0.7;
}
button.primary {
    background: linear-gradient(135deg, var(--primary) 0%, var(--primary-dark) 100%) !important;
}
.status-success { color: var(--success); }
.status-error { color: #ef4444; }
"""
92
+
93
def dataframe_to_html(df: pd.DataFrame) -> str:
    """Render *df* as the styled leaderboard ``<table>``.

    Rows whose "Rank" is 1, 2 or 3 get the ``rank-N`` class, and "Category"
    cells are wrapped in a coloured tag.  Headers and cell text are
    HTML-escaped, except for the "Paper" and "Code" columns, which are
    emitted as-is because they are expected to already hold anchor markup.

    Returns:
        The table markup, or a placeholder ``<div>`` when *df* is empty.
    """
    from html import escape

    if df.empty:
        return "<div class='placeholder'>No data available</div>"

    raw_html_cols = ("Paper", "Code")
    rank_classes = {1: "rank-1", 2: "rank-2", 3: "rank-3"}
    category_classes = {
        "Text-Conditioned": "cat-text",
        "One-hot": "cat-onehot",
        "Intrinsics/Extrinsics": "cat-camera"
    }
    has_rank = "Rank" in df.columns

    parts = ['<table class="leaderboard-table">', "<thead><tr>"]
    parts.extend(f"<th>{escape(str(col))}</th>" for col in df.columns)
    parts.append("</tr></thead>")

    parts.append("<tbody>")
    for _, row in df.iterrows():
        rank_class = rank_classes.get(row["Rank"], "") if has_rank else ""
        parts.append(f'<tr class="{rank_class}">')
        for col in df.columns:
            val = row[col]
            if col in raw_html_cols:
                cell = f"{val}"
            elif col == "Category":
                cat_class = category_classes.get(val, "")
                cell = f'<span class="category-tag {cat_class}">{escape(str(val))}</span>'
            else:
                # A "<" or "&" in a model name would otherwise break the markup.
                cell = escape(str(val))
            parts.append(f"<td>{cell}</td>")
        parts.append("</tr>")
    parts.append("</tbody></table>")
    return "".join(parts)
src/utils.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
def get_metric_choices():
    """List the selectable metric labels, each ending in its marker.

    The overall average comes first, followed by the eight individual
    metrics, which all share a second marker.
    """
    individual_metrics = (
        "Image Quality",
        "Brightness Consistency",
        "Color Temperature",
        "Sharpness Retention",
        "Motion Smoothness",
        "Trajectory Accuracy",
        "Memory Symmetry",
        "Trajectory Alignment",
    )
    return ["Average ⭐"] + [f"{name} πŸ“Š" for name in individual_metrics]
16
+
17
def clean_metric_names(metrics):
    """Return the given labels with their display markers removed.

    Both marker suffixes are dropped wherever they occur, then surrounding
    whitespace is trimmed.  Order and length of the input are preserved.
    """
    return [
        label.replace(" ⭐", "").replace(" πŸ“Š", "").strip()
        for label in metrics
    ]
24
+
25
def get_dimension_metrics():
    """Map each radar-chart dimension to the metric columns it groups.

    A fresh dict is built on every call; its key order is the order in which
    the dimensions are listed here.
    """
    generation_quality = [
        "Image Quality",
        "Brightness Consistency",
        "Color Temperature",
        "Sharpness Retention",
    ]
    trajectory_following = ["Motion Smoothness", "Trajectory Accuracy"]
    memory_ability = ["Memory Symmetry", "Trajectory Alignment"]

    return {
        "Generation Quality": generation_quality,
        "Trajectory Following": trajectory_following,
        "Memory Ability": memory_ability,
    }