# NUPA-Anonymous's picture
# update model names
# e708577
import dash
from dash import dcc, html, Input, Output, State, ALL
import dash_table
import pandas as pd
from collections import defaultdict
from process import get_data # 假设 get_data 提供必要的数据
# Load the benchmark results together with the checklist option lists
# (metric_checklist / range_checklist feed the sidebar selectors below).
(
    data,
    models,
    task_checklist,
    metric_checklist,
    range_checklist,
    colNames,
) = get_data()
# Leaderboard frame: one row per model, with the model names moved out of the
# index into a regular "model" column.
df = (
    pd.DataFrame(data, index=models)
    .reset_index()
    .rename(columns={"index": "model"})
)
# Tasks evaluated for each number representation. The key order of
# task_groups matters: the layout creates one checklist per group in this
# order, and the table callback pairs the checklist values with the group
# names in that same order.
intTasks = [
    "Add", "Sub", "Max", "Max_Hard", "Multiply_Hard", "Multiply_Easy",
    "Digit_Max", "Digit_Add", "Get_Digit", "Length", "Truediv", "Floordiv",
    "Mod", "Mod_Easy", "Count", "Sig", "To_Scient",
]
floatTasks = [
    "Add", "Sub", "Max", "Max_Hard", "Multiply_Hard", "Multiply_Easy",
    "Digit_Max", "Digit_Add", "Get_Digit", "Length", "To_Scient",
]
fractionTasks = [
    "Add", "Add_Easy", "Sub", "Max", "Multiply_Hard", "Multiply_Easy",
    "Truediv", "To_Float",
]
sciTasks = [
    "Add", "Sub", "Max", "Max_Hard", "Multiply_Hard", "Multiply_Easy",
    "To_Float",
]
task_groups = defaultdict(
    list,
    Integer=intTasks,
    Float=floatTasks,
    Fraction=fractionTasks,
    Scientificnotation=sciTasks,
)
# Create the Dash application.
app = dash.Dash(__name__)
# Page layout: title, links, authors, the collapsible selection sidebar and the
# leaderboard table.
app.layout = html.Div([
    html.H1("NUPA Benchmark Leaderboard", style={'textAlign': 'center', 'fontSize': '36px', 'color': '#1E2A47'}),
    html.Div([
        html.A("GitHub", href="https://github.com/GraphPKU/number_cookbook", target="_blank", style={'marginRight': '20px', 'fontSize': '16px'}),
        html.A("Paper", href="https://arxiv.org/abs/2411.03766", target="_blank", style={'marginRight': '20px', 'fontSize': '16px'}),
    ], style={'textAlign': 'center', 'marginBottom': '30px'}),
    html.Div([
        html.Span("Haotong Yang, Yi Hu, Shijia Kang, Zhouchen Lin, Muhan Zhang", style={'fontSize': '16px', 'color': '#555'}),
    ], style={'textAlign': 'center', 'marginBottom': '30px'}),
    # Sidebar toggle button and the sidebar container.
    html.Div([
        # Button that opens/closes the sidebar (handled by the toggle_sidebar callback).
        html.Button("Select Tasks", id="toggle-sidebar", n_clicks=0, style={'marginBottom': '20px', 'fontSize': '16px', 'padding': '10px 20px', 'borderRadius': '5px', 'backgroundColor': '#3f4b71', 'color': 'white'}),
        # The sidebar itself; hidden until the button is clicked.
        html.Div([
            html.Label("Select Tasks:", style={'fontWeight': 'bold', 'fontSize': '18px'}),
            # One checklist per task group, created in task_groups order. The
            # pattern-matching id lets a single callback read all of them.
            html.Div(id='task-selector-container', children=[
                html.Div([
                    html.Label(group, style={'fontWeight': 'bold', 'fontSize': '16px', 'color': '#3f4b71'}),
                    dcc.Checklist(
                        id={'type': 'task-selector', 'group': group},
                        options=[{'label': task, 'value': task} for task in tasks],
                        value=['Add'],  # every task group contains "Add"
                        inline=True,
                        style={'marginBottom': '10px'}
                    )
                ], style={'marginBottom': '10px'}) for group, tasks in task_groups.items()
            ]),
            html.Label("Select Metrics:", style={'fontWeight': 'bold', 'fontSize': '18px'}),
            dcc.Checklist(
                id='metric-selector',
                options=metric_checklist,
                value=['exact match'],  # default: only "exact match" is selected
                inline=True,
                style={'marginBottom': '20px'}
            ),
            html.Label("Select Range:", style={'fontWeight': 'bold', 'fontSize': '18px'}),
            dcc.Checklist(
                id='range-selector',
                options=range_checklist,
                value=['S & M', 'L & XL'],
                inline=True,
                style={'marginBottom': '20px'}
            ),
        ], id="sidebar", style={
            'position': 'fixed',
            'top': '0',
            'right': '0',
            'width': '300px',
            'height': '100vh',
            # Count the padding inside the 100vh height and let long checklists
            # scroll. Without this, a fixed-position box clips whatever falls
            # below the viewport, and the page cannot scroll it into view.
            'boxSizing': 'border-box',
            'overflowY': 'auto',
            'backgroundColor': '#f4f4f4',
            'padding': '20px',
            'boxShadow': '0px 4px 12px rgba(0, 0, 0, 0.1)',
            'zIndex': '100',
            'display': 'none',  # hidden by default
            # NOTE(review): 'display' is toggled at the same time as 'transform',
            # so this transition is not actually visible.
            'transition': 'transform 0.3s ease',
            'transform': 'translateX(100%)'
        }),
    ], style={'textAlign': 'center', 'marginBottom': '30px'}),
    # Main content: the leaderboard table. Its columns and data are filled in
    # by the update_columns_and_data callback.
    dash_table.DataTable(
        id='leaderboard-table',
        data=[],
        columns=[],
        merge_duplicate_headers=True,
        style_table={
            'height': '600px',  # taller table
            'overflowY': 'auto',
            'borderRadius': '10px',
            'boxShadow': '0 4px 12px rgba(0, 0, 0, 0.1)',
            'width': '90%',
            'margin': '0 auto',
            'backgroundColor': '#f4f4f4',
        },
        style_cell={
            'textAlign': 'center',
            'padding': '12px',
            'fontFamily': 'Arial, sans-serif',
            'fontSize': '14px',
            'border': '1px solid #e2e2e2',
            'backgroundColor': '#f9f9f9',
            'color': '#333',
        },
        style_header={
            'backgroundColor': '#3f4b71',
            'color': 'white',
            'fontWeight': 'bold',
            'textAlign': 'center',
            'padding': '12px 10px',
            'borderBottom': '1px solid #333',
            'textTransform': 'uppercase',
        },
        style_data={
            'backgroundColor': '#ffffff',
            'color': '#333',
            'borderBottom': '1px solid #e2e2e2',
        },
        # Highlight the computed average column.
        style_data_conditional=[{
            'if': {'column_id': 'average'},
            'backgroundColor': '#eaf7f8',
            'fontWeight': 'bold'
        }],
        sort_action="native",
        sort_mode="multi",
        page_size=10,
        # Initially sorted by the 'average' column, highest first.
        sort_by=[{
            'column_id': 'average',
            'direction': 'desc'
        }],
    )
])
# Show or hide the selection sidebar each time the toggle button is clicked.
@app.callback(
    Output("sidebar", "style"),
    [Input("toggle-sidebar", "n_clicks")],
    [State("sidebar", "style")]
)
def toggle_sidebar(n_clicks, sidebar_style):
    """Return the sidebar style matching the current click count.

    An odd ``n_clicks`` means the sidebar is open. An even count, including
    the initial 0, means it is closed. The incoming ``sidebar_style`` dict is
    updated in place and returned as the new style.
    """
    is_open = n_clicks % 2 != 0
    sidebar_style.update(
        display='block' if is_open else 'none',
        transform='translateX(0%)' if is_open else 'translateX(100%)',
    )
    return sidebar_style
# Rebuild the table's columns and rows (with a per-row average over the
# selected columns) whenever the selection changes.
@app.callback(
    [Output('leaderboard-table', 'columns'), Output('leaderboard-table', 'data')],
    [Input({'type': 'task-selector', 'group': ALL}, 'value'),
     Input('metric-selector', 'value'),
     Input('range-selector', 'value'),
     Input('sidebar', 'style')],
    [State({'type': 'task-selector', 'group': ALL}, 'options')]
)
def update_columns_and_data(task_values, selected_metrics, selected_ranges, sidebar_style, task_options):
    """Build the leaderboard columns and rows for the current selection.

    Args:
        task_values: one list of checked task names per task-group checklist,
            in layout order.
        selected_metrics: checked metric values.
        selected_ranges: checked range values.
        sidebar_style: unused. It is an input only so that toggling the
            sidebar also refreshes the table.
        task_options: unused.

    Returns:
        ``(columns, data)`` for the DataTable.
        ``columns`` holds the "model" column, then one three-level header
        (task, metric, range) per selected combination, then a
        "selected average" column.
        ``data`` holds one record per model, limited to the displayed columns.
        Its "average" is the row mean over the selected columns that exist in
        ``df``, rounded to 2 decimals; it is NaN when none exist.
    """
    # The ALL pattern delivers the checklist values in layout order. The layout
    # builds those checklists by iterating task_groups, so pairing with its
    # keys gives each task its group suffix (e.g. "Add Integer").
    selected_tasks = [
        f"{task} {group_name}"
        for group_name, checked in zip(task_groups, task_values)
        for task in checked or []
    ]
    metrics = selected_metrics or []
    ranges = selected_ranges or []

    columns = [{"name": ["", "", "model"], "id": "model"}]
    selected_data_columns = []
    for task in selected_tasks:
        for metric in metrics:
            for value_range in ranges:  # named to avoid shadowing builtin range
                col_id = f"{task} - {metric} - {value_range}"
                columns.append({"name": [task, metric, value_range], "id": col_id})
                selected_data_columns.append(col_id)
    columns.append({"name": ["", "", "selected average"], "id": "average"})

    # Average only over combinations the results frame actually has. A missing
    # column would otherwise make df[...] raise KeyError and abort the update.
    present = [col for col in selected_data_columns if col in df.columns]
    # Send only the columns the table displays instead of the whole frame.
    keep = (["model"] if "model" in df.columns else []) + present
    df_display = df[keep].copy()
    df_display["average"] = df_display[present].mean(axis=1).round(2)
    return columns, df_display.to_dict('records')
if __name__ == '__main__':
    import os

    # The server listens on every interface, so Dash/Flask debug mode (dev
    # tools, debugger, hot reload) is opt-in via DASH_DEBUG instead of being
    # always on.
    debug = os.environ.get("DASH_DEBUG", "").strip().lower() in {"1", "true", "yes"}
    app.run(debug=debug, host='0.0.0.0', port=7860)