RTEB / app /backend /data_page.py
q275343119's picture
fix - Page Not Found
80cebd7
raw
history blame
13.3 kB
# -*- coding: utf-8 -*-
# @Date : 2025/2/5 16:26
# @Author : q275343119
# @File : data_page.py
# @Description:
import io
from st_aggrid import AgGrid, JsCode, ColumnsAutoSizeMode
import streamlit as st
from utils.st_copy_to_clipboard import st_copy_to_clipboard
from streamlit_theme import st_theme
from app.backend.app_init_func import LEADERBOARD_MAP
from app.backend.constant import LEADERBOARD_ICON_MAP, BASE_URL
from app.backend.json_util import compress_msgpack, decompress_msgpack
# Model metadata columns selected for every leaderboard table; render_page
# places them in front of the score columns.
COLUMNS = [
    "model_name",
    "embd_dtype",
    "embd_dim",
    "num_params",
    "max_tokens",
    "similarity",
    "query_instruct",
    "corpus_instruct",
    "reference",
]

# Font size applied to grid headers and grid cells respectively.
HEADER_STYLE = dict(fontSize="18px")
CELL_STYLE = dict(fontSize="18px")
def is_section(group_name):
    """Tell whether ``group_name`` is the name of a leaderboard section.

    Every value of ``LEADERBOARD_MAP`` is inspected; the name compared
    against is ``value[0][0]``.

    :param group_name: name to look up
    :return: True on the first match, False when no entry matches
    """
    return any(group_name == entry[0][0] for entry in LEADERBOARD_MAP.values())
def get_closed_dataset():
    """Collect the ``dataset_name`` of every result flagged as closed.

    Results are read from the ``data_engine`` stored in the Streamlit
    session state; a result counts as closed when its ``is_closed`` value
    is truthy.

    :return: list of dataset names, in the order the results are stored
    """
    engine = st.session_state["data_engine"]
    return [
        entry.get("dataset_name")
        for entry in engine.results
        if entry.get("is_closed")
    ]
def convert_df_to_csv(df):
    """Serialize ``df`` to CSV text without the index column.

    :param df: DataFrame to export
    :return: the CSV content as a single string
    """
    with io.StringIO() as buffer:
        df.to_csv(buffer, index=False)
        return buffer.getvalue()
def get_column_state():
    """Load the grid state carried by the ``grid_state`` URL query parameter.

    When the parameter is present and non-empty it is decoded with
    ``decompress_msgpack`` and stored in ``st.session_state.grid_state``.

    :return: True when a state was stored, otherwise None
    """
    encoded_state = st.query_params.get("grid_state", None)
    if not encoded_state:
        return None
    st.session_state.grid_state = decompress_msgpack(encoded_state)
    return True
def render_page(group_name):
    """Render the leaderboard page for ``group_name``.

    In order, the page shows: the title and theme CSS, a "Download CSV"
    button for the full results frame, an AgGrid table with the metadata
    and score columns for the group, and a "Share this page" button that
    opens a dialog with a link containing the current grid state.

    :param group_name: name of the dataset group (or leaderboard section,
        see ``is_section``) to display
    """
    # NOTE(review): grid_state is read here before get_column_state() stores
    # the URL-provided state in the session, so a state passed via the URL
    # only reaches "initialState" on a later rerun - confirm this ordering
    # is intended.
    grid_state = st.session_state.get("grid_state", {})
    get_column_state()
    # Add theme color and grid styles
    st.title("Retrieval Embedding Benchmark (RTEB)")
    st.markdown("""
<style>
:root {
--theme-color: rgb(129, 150, 64);
--theme-color-light: rgba(129, 150, 64, 0.2);
}
/* AG Grid specific overrides */
.ag-theme-alpine {
--ag-selected-row-background-color: var(--theme-color-light) !important;
--ag-row-hover-color: var(--theme-color-light) !important;
--ag-selected-tab-color: var(--theme-color) !important;
--ag-range-selection-border-color: var(--theme-color) !important;
--ag-range-selection-background-color: var(--theme-color-light) !important;
}
.ag-row-hover {
background-color: var(--theme-color-light) !important;
}
.ag-row-selected {
background-color: var(--theme-color-light) !important;
}
.ag-row-focus {
background-color: var(--theme-color-light) !important;
}
.ag-cell-focus {
border-color: var(--theme-color) !important;
}
/* Keep existing styles */
.center-text {
text-align: center;
color: var(--theme-color);
}
.center-image {
display: block;
margin-left: auto;
margin-right: auto;
}
h2 {
color: var(--theme-color) !important;
}
.ag-header-cell {
background-color: var(--theme-color) !important;
color: white !important;
}
a {
color: var(--theme-color) !important;
}
a:hover {
color: rgba(129, 150, 64, 0.8) !important;
}
/* Download Button */
button[data-testid="stBaseButton-secondary"] {
float: right;
}
/* Toast On The Top*/
div[data-testid="stToastContainer"] {
position: fixed !important;
z-index: 2147483647 !important;
}
</style>
""", unsafe_allow_html=True)
    # logo
    # st.markdown('<img src="https://www.voyageai.com/logo.svg" class="center-image" width="200">', unsafe_allow_html=True)
    # Default heading: icon (if any) + capitalized group name; section pages
    # get the " Leaderboard" suffix in both the icon lookup key and the text.
    title = f'<h2 class="center-text">{LEADERBOARD_ICON_MAP.get(group_name.capitalize(), "")} {group_name.capitalize()}</h2>'
    if is_section(group_name):
        title = f'<h2 class="center-text">{LEADERBOARD_ICON_MAP.get(group_name.capitalize() + " Leaderboard", "")} {group_name.capitalize() + " Leaderboard"}</h2>'
    # title
    st.markdown(title, unsafe_allow_html=True)
    data_engine = st.session_state["data_engine"]
    # Work on a copy so the columns added below do not touch the frame
    # returned by the data engine.
    df = data_engine.jsons_to_df().copy()
    # The CSV export is taken before any column selection/renaming below,
    # so it contains the full frame.
    csv = convert_df_to_csv(df)
    file_name = f"{group_name.capitalize()} Leaderboard" if is_section(group_name) else group_name.capitalize()
    st.download_button(
        label="Download CSV",
        data=csv,
        file_name=f"{file_name}.csv",
        mime="text/csv",
        icon=":material/download:",
    )
    # get columns
    column_list = []
    avg_column = None
    if is_section(group_name):
        # Section pages: collect every column whose name contains "Average".
        # Columns that start with "Average" are moved to the front, and the
        # first entry becomes the sort column.
        avg_columns = []
        for column in df.columns:
            if column.startswith("Average"):
                avg_columns.insert(0, column)
                continue
            if "Average" in column:
                avg_columns.append(column)
                continue
        avg_column = avg_columns[0]
        column_list.extend(avg_columns)
    else:
        # Group pages: the average column is the one prefixed with the
        # capitalized group name followed by a space (last match wins).
        for column in df.columns:
            if column.startswith(group_name.capitalize() + " "):
                avg_column = column
        # NOTE(review): nesting of this append (after the loop vs. inside the
        # match) should be confirmed against the original file.
        column_list.append(avg_column)
    # Datasets that belong to this group, as listed by the data engine.
    dataset_list = []
    for dataset_dict in data_engine.datasets:
        if dataset_dict["name"] == group_name:
            dataset_list = dataset_dict["datasets"]
    # NOTE(review): the statements below are read as function-level (not
    # nested in the loop above) - confirm against the original file.
    if not is_section(group_name):
        # Only group pages show one column per dataset.
        column_list.extend(dataset_list)
    # Row-wise means over the group's closed and open datasets, rounded to
    # two decimals.
    closed_list = get_closed_dataset()
    close_avg_list = list(set(dataset_list) & set(closed_list))
    df["Closed average"] = df[close_avg_list].mean(axis=1).round(2)
    column_list.append("Closed average")
    open_avg_list = list(set(dataset_list) - set(closed_list))
    df["Open average"] = df[open_avg_list].mean(axis=1).round(2)
    column_list.append("Open average")
    # Keep metadata + selected score columns, best average first.
    df = df[COLUMNS + column_list].sort_values(by=avg_column, ascending=False)
    # rename avg column name
    if not is_section(group_name):
        # Drop the capitalized group name from the average column's name and
        # track the column under its new name from here on.
        new_column = avg_column.replace(group_name.capitalize(), "").strip()
        df.rename(columns={avg_column: new_column}, inplace=True)
        column_list.remove(avg_column)
        avg_column = new_column
    # setting column config
    grid_options = {
        'columnDefs': [
            {
                'headerName': 'Model Name',
                'field': 'model_name',
                'pinned': 'left',
                'sortable': False,
                'headerStyle': HEADER_STYLE,
                'cellStyle': CELL_STYLE,
                "tooltipValueGetter": JsCode(
                    """function(p) {return p.value}"""
                ),
                "width": 250,
                # Renders the model name as a link to the row's "reference"
                # value when one is present, otherwise as plain text.
                'cellRenderer': JsCode("""class CustomHTML {
init(params) {
const link = params.data.reference;
this.eGui = document.createElement('div');
this.eGui.innerHTML = link ?
`<a href="${link}" class="a-cell" target="_blank">${params.value} </a>` :
params.value;
}
getGui() {
return this.eGui;
}
}"""),
                'suppressSizeToFit': True
            },
            {'headerName': "Overall Score",
             'field': avg_column,
             'headerStyle': HEADER_STYLE,
             'cellStyle': CELL_STYLE,
             # 'suppressSizeToFit': True
             },
            # Add Open average column definition
            {'headerName': 'Open Average',
             'field': 'Open average',
             'headerStyle': HEADER_STYLE,
             'cellStyle': CELL_STYLE,
             # 'suppressSizeToFit': True
             },
            {'headerName': 'Closed Average',
             'field': 'Closed average',
             'headerStyle': HEADER_STYLE,
             'cellStyle': CELL_STYLE,
             # 'suppressSizeToFit': True
             },
            {
                'headerName': 'Embd Dtype',
                'field': 'embd_dtype',
                'headerStyle': HEADER_STYLE,
                'cellStyle': CELL_STYLE,
                # 'suppressSizeToFit': True,
            },
            {
                'headerName': 'Embd Dim',
                'field': 'embd_dim',
                'headerStyle': HEADER_STYLE,
                'cellStyle': CELL_STYLE,
                # 'suppressSizeToFit': True,
            },
            {
                'headerName': 'Number of Parameters',
                'field': 'num_params',
                'cellDataType': 'number',
                "colId": "num_params",
                'headerStyle': HEADER_STYLE,
                'cellStyle': CELL_STYLE,
                # Formats the value with a B / M / K suffix and two decimals;
                # values below 1e3 are shown unchanged.
                'valueFormatter': JsCode(
                    """function(params) {
const num = params.value;
if (num >= 1e9) return (num / 1e9).toFixed(2) + "B";
if (num >= 1e6) return (num / 1e6).toFixed(2) + "M";
if (num >= 1e3) return (num / 1e3).toFixed(2) + "K";
return num;
}"""
                ),
                "width": 120,
                # 'suppressSizeToFit': True,
            },
            {
                'headerName': 'Context Length',
                'field': 'max_tokens',
                'headerStyle': HEADER_STYLE,
                'cellStyle': CELL_STYLE,
                # 'suppressSizeToFit': True,
            },
            # One definition per remaining score column; the overall, open and
            # closed averages already have explicit definitions above. For
            # names containing "Average" the word is removed from the header
            # and tooltip text.
            *[{'headerName': column if "Average" not in column else column.replace("Average", "").strip().capitalize(),
               'field': column,
               'headerStyle': HEADER_STYLE,
               'cellStyle': CELL_STYLE,
               "headerTooltip": column if "Average" not in column else column.replace("Average",
                                                                                      "").strip().capitalize()
               # 'suppressSizeToFit': True
               } for column in column_list if
              column not in (avg_column, "Closed average", "Open average")]
        ],
        'defaultColDef': {
            'filter': True,
            'sortable': True,
            'resizable': True,
            'headerClass': "multi-line-header",
            'autoHeaderHeight': True,
            'width': 105
        },
        # Auto-size only the per-column score fields generated above.
        "autoSizeStrategy": {
            "type": 'fitCellContents',
            "colIds": [column for column in column_list if column not in (avg_column, "Closed average", "Open average")]
        },
        "tooltipShowDelay": 500,
        # State taken from the session at the top of this function.
        "initialState": grid_state,
    }
    custom_css = {
        # Model Name Cell
        ".a-cell": {
            "display": "inline-block",
            "white-space": "nowrap",
            "overflow": "hidden",
            "text-overflow": "ellipsis",
            "width": "100%",
            "min-width": "0"
        },
        # Header
        ".multi-line-header": {
            "text-overflow": "clip",
            "overflow": "visible",
            "white-space": "normal",
            "height": "auto",
            "font-family": 'Arial',
            "font-size": "14px",
            "font-weight": "bold",
            "padding": "10px",
            "text-align": "left",
        }
        ,
        # Filter Options and Input
        ".ag-theme-streamlit .ag-popup": {
            "font-family": 'Arial',
            "font-size": "14px",
        }
        , ".ag-picker-field-display": {
            "font-family": 'Arial',
            "font-size": "14px",
        },
        ".ag-input-field-input .ag-text-field-input": {
            "font-family": 'Arial',
            "font-size": "14px",
        }
    }
    # allow_unsafe_jscode is required for the JsCode snippets in grid_options.
    grid = AgGrid(
        df,
        enable_enterprise_modules=False,
        gridOptions=grid_options,
        allow_unsafe_jscode=True,
        columns_auto_size_mode=ColumnsAutoSizeMode.FIT_ALL_COLUMNS_TO_VIEW,
        theme="streamlit",
        custom_css=custom_css,
        update_on=["stateUpdated"],
    )

    @st.dialog("URL")
    def share_url():
        # Dialog body: build a link to this page, show it, and offer a
        # copy-to-clipboard button.
        state = grid.grid_state
        if state:
            # Group pages link to BASE_URL + group name; section pages link to
            # BASE_URL itself. Underscores in BASE_URL are replaced by hyphens,
            # and the compressed grid state travels as a query parameter.
            share_link = f'{BASE_URL.replace("_", "-")}{group_name}/?grid_state={compress_msgpack(state)}' if not is_section(
                group_name) else f'{BASE_URL.replace("_", "-")}?grid_state={compress_msgpack(state)}'
        else:
            # No grid state available: link to the page without a state.
            share_link = f'{BASE_URL.replace("_", "-")}{group_name}'
        st.write(share_link)
        # Use the "base" entry of the detected theme; fall back to "light"
        # when st_theme() returns nothing.
        theme = st_theme()
        if theme:
            theme = theme.get("base")
        else:
            theme = "light"
        st_copy_to_clipboard(share_link, before_copy_label='📋Push to copy', after_copy_label='✅Text copied!',
                             theme=theme)

    share_btn = st.button("Share this page", icon=":material/share:")
    if share_btn:
        share_url()