# iWorld-Bench / app.py
# Commit 6e61bd8 by iWorldBench: "Update title links & styling"
import gradio as gr
import pandas as pd
import matplotlib.pyplot as plt
from src.data_loader import DataLoader
from src.leaderboard import Leaderboard
from src.plotter import Plotter
from src.radar_plotter import RadarPlotter
from src.styling import dataframe_to_html, get_academic_css
from src.utils import get_metric_choices, clean_metric_names
# Backend objects. The leaderboard table, the comparison-chart plotter and the
# radar plotter are all constructed around this single DataLoader instance.
data_loader = DataLoader(results_dir="./data")
leaderboard, plotter, radar_plotter = (
    Leaderboard(data_loader),
    Plotter(data_loader),
    RadarPlotter(data_loader),
)

# Initial selection for both the ranking dropdown and the comparison-plot radio.
DEFAULT_METRIC = "Average ⭐"

# Resource-link bar rendered under the page title. Each entry is one line of
# markup; the empty first and last entries give the HTML a leading and a
# trailing newline.
TITLE_RESOURCE_LINKS = "\n".join((
    "",
    '<div class="project-links-bar">',
    '<a class="pl-link pl-project" href="https://iworld-bench.com/" target="_blank" rel="noopener noreferrer"><i class="fa-solid fa-globe" aria-hidden="true"></i><em>Project Page</em></a>',
    '<a class="pl-link pl-dataset" href="https://huggingface.co/datasets/EmbodiedCity/iWorld-Bench-Dataset" target="_blank" rel="noopener noreferrer"><i class="fa-solid fa-database" aria-hidden="true"></i><em>Dataset</em></a>',
    '<a class="pl-link pl-code" href="https://github.com/EmbodiedCity/iWorld-Bench" target="_blank" rel="noopener noreferrer"><i class="fa-brands fa-github" aria-hidden="true"></i><em>Code</em></a>',
    '<a class="pl-link pl-leaderboard" href="https://huggingface.co/spaces/EmbodiedCity/iWorld-Bench" target="_blank" rel="noopener noreferrer"><i class="fa-solid fa-trophy" aria-hidden="true"></i><em>Leaderboard</em></a>',
    "</div>",
    "",
))
def reload_data():
    """Reload results from disk and rebuild the default leaderboard view.

    Wired to both the "Reload Data" button and ``demo.load``. The four return
    values map to ``[status_box, category_dropdown, leaderboard_html,
    radar_plot]``.

    Returns:
        tuple: ``(status_text, category_dropdown_update, table_html,
        radar_figure)``. ``status_text`` is always "". When no data is
        loaded, the category dropdown is reset to just "All", the table is a
        placeholder, and the figure is a text-only matplotlib figure showing
        the message returned by ``data_loader.reload_data()``.
    """
    msg = data_loader.reload_data()

    if data_loader.df_all is None or data_loader.df_all.empty:
        dummy_fig, ax = plt.subplots(figsize=(6, 3))
        ax.text(0.5, 0.5, msg, ha="center", va="center")
        ax.axis("off")
        # Unregister the figure from pyplot's global figure manager. This
        # function runs on every page load, and unclosed pyplot figures would
        # otherwise accumulate for the lifetime of the server process. The
        # Figure object itself can still be rendered by gr.Plot.
        plt.close(dummy_fig)
        placeholder_html = "<div class='placeholder'>No data available</div>"
        # Empty status text, category dropdown reset to only "All",
        # placeholder table, and the message figure.
        return "", gr.update(choices=["All"], value="All"), placeholder_html, dummy_fig

    category_choices = data_loader.get_category_choices()
    # Show every metric except the Average entry by default; "Average" is then
    # always prepended so that column is always present.
    extra_metrics = [m for m in get_metric_choices() if m != "Average ⭐"]
    selected = ["Average"] + clean_metric_names(extra_metrics)

    # These arguments mirror the UI's initial control values
    # (metric "Average ⭐", Top-K 25, empty name filter, category "All", sort "Auto").
    table_df = leaderboard.update_leaderboard(
        metric="Average",
        top_k=25,
        model_filter="",
        open_source_filter="All",
        year_filter="All",
        category_filter="All",
        sort_mode="Auto",
        selected_metrics=selected,
    )
    html_table = dataframe_to_html(table_df)

    # Plot only the models shown in the table, exactly as
    # update_leaderboard_wrapper does. This keeps the radar chart's scope
    # unchanged after the first "Update Leaderboard" click.
    if table_df is not None and not table_df.empty:
        displayed_models = table_df["Model"].tolist()
        radar_df = data_loader.df_all[data_loader.df_all["Model"].isin(displayed_models)].copy()
    else:
        radar_df = pd.DataFrame()
    radar_fig = radar_plotter.create_radar_chart(radar_df)

    return "", gr.update(choices=category_choices, value="All"), html_table, radar_fig
def update_leaderboard_wrapper(metric, top_k, model_filter,
                               category_filter, sort_mode, selected_metrics):
    """Rebuild the leaderboard table and radar chart from the current controls.

    Bound to the "Update Leaderboard" button. The open-source and year filters
    are pinned to "All" because the UI no longer exposes them.

    Args:
        metric: Label chosen in the ranking dropdown (it may carry a display
            marker, e.g. "Average ⭐"). Normalized with ``clean_metric_names``.
        top_k: Number of rows to keep. Cast to ``int`` because a slider value
            may arrive as a float, depending on the Gradio version.
        model_filter: Partial model-name filter text. ``None`` is treated as "".
        category_filter: Selected category, or "All".
        sort_mode: Label from the "Sort Order" radio, forwarded unchanged.
        selected_metrics: Extra metric labels from the checkbox group.
            ``None`` is treated as an empty selection.

    Returns:
        tuple: ``(table_html, radar_figure)``. The radar chart is drawn from
        the rows of ``data_loader.df_all`` whose model appears in the table,
        or from an empty DataFrame when nothing is displayed.
    """
    clean_metric = clean_metric_names([metric])[0]
    # "Average" is always displayed. dict.fromkeys drops a duplicate, should the
    # cleaned selection also contain it, while preserving column order.
    clean_selected = list(dict.fromkeys(
        ["Average"] + clean_metric_names(selected_metrics or [])
    ))

    table_df = leaderboard.update_leaderboard(
        clean_metric,
        int(top_k),
        model_filter or "",
        open_source_filter="All",
        year_filter="All",
        category_filter=category_filter,
        sort_mode=sort_mode,
        selected_metrics=clean_selected,
    )
    html_table = dataframe_to_html(table_df)

    if table_df.empty or data_loader.df_all is None:
        radar_df = pd.DataFrame()
    else:
        displayed_models = table_df["Model"].tolist()
        radar_df = data_loader.df_all[data_loader.df_all["Model"].isin(displayed_models)].copy()
    radar_fig = radar_plotter.create_radar_chart(radar_df)

    return html_table, radar_fig
def create_comparison_plot_wrapper(model_filter, category_filter,
                                   selected_plot_metric, plot_sort_mode):
    """Render the single-metric model comparison chart for the current filters.

    The selected radio label is normalized with ``clean_metric_names`` before
    it is passed to ``plotter.create_comparison_plot``. The open-source and
    year filters are fixed at "All" since the UI no longer offers them.

    Returns:
        Whatever ``plotter.create_comparison_plot`` returns; it is displayed in
        the ``comparison_plot`` component.
    """
    metric_name = clean_metric_names([selected_plot_metric])[0]
    fixed_filters = {"open_source_filter": "All", "year_filter": "All"}
    return plotter.create_comparison_plot(
        model_filter,
        category_filter=category_filter,
        metric=metric_name,
        sort_mode=plot_sort_mode,
        **fixed_filters,
    )
academic_css = get_academic_css()

with gr.Blocks(css=academic_css) as demo:
    gr.Markdown(
        """
# <span class="emoji">🌍</span> iWorld-Bench Leaderboard
<span class="subtitle">A Benchmark for Interactive World Models with a Unified Action Generation Framework</span>
""",
        elem_id="title",
    )
    gr.HTML(TITLE_RESOURCE_LINKS)

    # Hidden sink for reload_data's first return value.
    status_box = gr.Markdown(visible=False)

    with gr.Row():
        with gr.Column(scale=2):
            metric_choices = get_metric_choices()
            metric_dropdown = gr.Dropdown(
                label="Primary Ranking Metric",
                choices=metric_choices,
                value=DEFAULT_METRIC,
                interactive=True,
            )
        with gr.Column(scale=1):
            # The selected label is forwarded unchanged to
            # leaderboard.update_leaderboard via update_leaderboard_wrapper.
            sort_mode_radio = gr.Radio(
                label="Sort Order",
                choices=["Auto", "Ascending (low → high)", "Descending (high → low)"],
                value="Auto",
                interactive=True,
            )
            topk_slider = gr.Slider(
                label="Display Top-K Models",
                minimum=3, maximum=50, value=25, step=1,
                interactive=True,
            )

    # Every metric except the Average entry. The event handlers always add the
    # "Average" column back themselves.
    additional_metric_choices = [m for m in metric_choices if m != "Average ⭐"]
    with gr.Row():
        metrics_select = gr.CheckboxGroup(
            label="Additional Metrics to Display (📊 indicates dimension metrics)",
            choices=additional_metric_choices,
            value=list(additional_metric_choices),
            interactive=True,
        )

    with gr.Row():
        with gr.Column(scale=1):
            model_filter_box = gr.Textbox(
                label="Filter by Model Name",
                placeholder="Enter model name (partial match)",
                interactive=True,
            )
        # Open Source and Year filters were removed from the UI; the handlers
        # pin them to "All".
        with gr.Column(scale=1):
            # Real choices are filled in by reload_data (bound to demo.load).
            category_dropdown = gr.Dropdown(
                label="Filter by Category",
                choices=["All"],
                value="All",
                interactive=True,
            )

    with gr.Row():
        reload_button = gr.Button("🔄 Reload Data", variant="secondary", size="sm")
        update_button = gr.Button("✅ Update Leaderboard", variant="primary", size="sm")

    leaderboard_html = gr.HTML(
        label="Leaderboard Table",
        value="<div class='placeholder'>Leaderboard will be displayed here...</div>",
    )

    with gr.Row():
        with gr.Column(scale=1):
            radar_plot = gr.Plot(label="Performance Radar (8 metrics)", format="png")
        with gr.Column(scale=1):
            plot_metric_radio = gr.Radio(
                label="Select Metric for Comparison Plot",
                choices=metric_choices,
                value=DEFAULT_METRIC,
                interactive=True,
            )
            plot_sort_radio = gr.Radio(
                label="Plot Sort Order",
                choices=["Ascending (low → high)", "Descending (high → low)"],
                value="Descending (high → low)",
                interactive=True,
            )
            plot_update_button = gr.Button("📊 Generate Comparison Plot", variant="primary", size="sm")
            comparison_plot = gr.Plot(label="Model Comparison Visualization", format="png")

    # Event bindings. They are registered inside the Blocks context.
    # reload_data takes no inputs and refreshes the status sink, category
    # choices, table and radar chart.
    reload_button.click(
        fn=reload_data,
        inputs=[],
        outputs=[status_box, category_dropdown, leaderboard_html, radar_plot],
    )
    update_button.click(
        fn=update_leaderboard_wrapper,
        inputs=[
            metric_dropdown, topk_slider, model_filter_box,
            category_dropdown, sort_mode_radio, metrics_select,
        ],
        outputs=[leaderboard_html, radar_plot],
    )
    plot_update_button.click(
        fn=create_comparison_plot_wrapper,
        inputs=[
            model_filter_box, category_dropdown,
            plot_metric_radio, plot_sort_radio,
        ],
        outputs=[comparison_plot],
    )
    # Populate the page on load with the same refresh the Reload button performs.
    demo.load(
        fn=reload_data,
        inputs=[],
        outputs=[status_box, category_dropdown, leaderboard_html, radar_plot],
    )
if __name__ == "__main__":
    import os

    # On HF Spaces keep the default (share disabled). On Docker or other
    # locked-down hosts, export GRADIO_SHARE=true (also accepts 1 / yes) to
    # launch with share=True. Host and port are overridable via
    # GRADIO_SERVER_NAME / GRADIO_SERVER_PORT.
    env = os.environ
    share_setting = env.get("GRADIO_SHARE", "false").strip().lower()
    demo.launch(
        server_name=env.get("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(env.get("GRADIO_SERVER_PORT", "7860")),
        share=share_setting in ("1", "true", "yes"),
    )