# -*- coding: utf-8 -*-
# @Date     : 2025/2/5 16:26
# @Author   : q275343119
# @File     : data_page.py
# @Description:
from st_aggrid import AgGrid, JsCode, ColumnsAutoSizeMode | |
import streamlit as st | |
from app.backend.app_init_func import LEADERBOARD_MAP | |
# Metadata columns that are always selected from the results frame, ahead of
# the score columns chosen per page.
COLUMNS = [
    'model_name',
    'embd_dtype',
    'embd_dim',
    'num_params',
    'max_tokens',
    'similarity',
    'query_instruct',
    'corpus_instruct',
    'reference',
]

# Style dicts attached to every grid column definition as 'headerStyle' and
# 'cellStyle' respectively.
HEADER_STYLE = {'fontSize': '18px'}
CELL_STYLE = {'fontSize': '18px'}
def is_section(group_name):
    """Return True if ``group_name`` is a leaderboard name in ``LEADERBOARD_MAP``.

    Each value of ``LEADERBOARD_MAP`` is indexed as ``value[0][0]`` to obtain
    the leaderboard name that ``group_name`` is compared against.
    NOTE(review): presumably each value is a non-empty sequence of
    (name, ...) entries -- confirm against ``app_init_func``.
    """
    # Only the values are needed; the map keys play no part in the match.
    return any(
        group_name == entries[0][0]
        for entries in LEADERBOARD_MAP.values()
    )
def get_closed_dataset():
    """Return the ``dataset_name`` of every result flagged as closed.

    Reads ``data_engine`` from ``st.session_state`` and scans its ``results``;
    an entry counts as closed when its ``"is_closed"`` value is truthy.
    Names are returned in the order the results are stored and are not
    de-duplicated. An entry without a ``"dataset_name"`` key yields ``None``.
    """
    data_engine = st.session_state["data_engine"]
    return [
        result.get("dataset_name")
        for result in data_engine.results
        if result.get("is_closed")
    ]
def _inject_theme_css():
    """Emit the page-wide CSS (theme colours and AG Grid overrides)."""
    st.markdown("""
    <style>
    :root {
        --theme-color: rgb(129, 150, 64);
        --theme-color-light: rgba(129, 150, 64, 0.2);
    }
    /* AG Grid specific overrides */
    .ag-theme-alpine {
        --ag-selected-row-background-color: var(--theme-color-light) !important;
        --ag-row-hover-color: var(--theme-color-light) !important;
        --ag-selected-tab-color: var(--theme-color) !important;
        --ag-range-selection-border-color: var(--theme-color) !important;
        --ag-range-selection-background-color: var(--theme-color-light) !important;
    }
    .ag-row-hover {
        background-color: var(--theme-color-light) !important;
    }
    .ag-row-selected {
        background-color: var(--theme-color-light) !important;
    }
    .ag-row-focus {
        background-color: var(--theme-color-light) !important;
    }
    .ag-cell-focus {
        border-color: var(--theme-color) !important;
    }
    /* Keep existing styles */
    .center-text {
        text-align: center;
        color: var(--theme-color);
    }
    .center-image {
        display: block;
        margin-left: auto;
        margin-right: auto;
    }
    h2 {
        color: var(--theme-color) !important;
    }
    .ag-header-cell {
        background-color: var(--theme-color) !important;
        color: white !important;
    }
    a {
        color: var(--theme-color) !important;
    }
    a:hover {
        color: rgba(129, 150, 64, 0.8) !important;
    }
    </style>
    """, unsafe_allow_html=True)


def _styled_column(field, header_name=None, **extra):
    """Build one AG Grid column definition with the shared styling.

    ``header_name`` defaults to ``field``; ``extra`` entries are merged into
    the definition (e.g. ``cellDataType='number'``).
    """
    col_def = {
        'headerName': field if header_name is None else header_name,
        'field': field,
        'headerStyle': HEADER_STYLE,
        'cellStyle': CELL_STYLE,
        'suppressSizeToFit': True,
    }
    col_def.update(extra)
    return col_def


def render_page(group_name):
    """Render the leaderboard grid for ``group_name``.

    Steps:
      1. Inject the theme CSS and the page title.
      2. Load the results frame from ``st.session_state["data_engine"]``.
      3. Choose score columns. For a section page (``is_section``) these are
         every column containing "Average", with a column that *starts with*
         "Average" moved to the front. Otherwise they are the columns
         prefixed with ``group_name.capitalize() + " "`` (prefix stripped)
         followed by the group's datasets.
      4. Add a "Closed average" column: the row-wise mean over the group's
         datasets that ``get_closed_dataset`` reports as closed.
      5. Append the first column starting with "Open average", if any.
      6. Sort descending by the primary average column and show the frame in
         an AG Grid.

    Args:
        group_name: Name of the leaderboard section or dataset group to show.
    """
    _inject_theme_css()
    # logo
    # st.markdown('<img src="https://www.voyageai.com/logo.svg" class="center-image" width="200">', unsafe_allow_html=True)
    # title
    st.markdown('<h2 class="center-text">Embedding Benchmark For Retrieval</h2>', unsafe_allow_html=True)

    data_engine = st.session_state["data_engine"]
    # Work on a private copy so renames/new columns never touch the frame
    # held by the data engine.
    df = data_engine.jsons_to_df().copy()
    section_page = is_section(group_name)

    # get columns
    column_list = []
    avg_column = None
    if section_page:
        avg_columns = []
        for column in df.columns:
            if column.startswith("Average"):
                # The overall average goes first: it is the sort key.
                avg_columns.insert(0, column)
            elif "Average" in column:
                avg_columns.append(column)
        if avg_columns:
            avg_column = avg_columns[0]
        column_list.extend(avg_columns)
    else:
        prefix = group_name.capitalize()
        # Build the rename map first instead of renaming the frame while
        # iterating over its columns.
        rename_map = {
            column: column.replace(prefix, "").strip()
            for column in df.columns
            if column.startswith(prefix + " ")
        }
        df = df.rename(columns=rename_map)
        column_list.extend(rename_map.values())
        if rename_map:
            # As before, the last matching column is the sort key.
            avg_column = column_list[-1]

    dataset_list = []
    for dataset_dict in data_engine.datasets:
        if dataset_dict["name"] == group_name:
            dataset_list = dataset_dict["datasets"]
    if not section_page:
        column_list.extend(dataset_list)

    # Closed average: only datasets that are closed AND present in the frame,
    # so a missing column cannot raise KeyError. With no such dataset the
    # column is all-NaN.
    closed_names = set(get_closed_dataset())
    close_avg_list = list(dict.fromkeys(
        name for name in dataset_list
        if name in closed_names and name in df.columns
    ))
    df["Closed average"] = df[close_avg_list].mean(axis=1)
    column_list.append("Closed average")

    # Add Open average to the column list if it's not already there
    open_avg_col = next((col for col in df.columns if col.startswith("Open average")), None)
    if open_avg_col and open_avg_col not in column_list:
        column_list.append(open_avg_col)

    # Select each column once; a duplicated label would make the sort key
    # ambiguous.
    df = df[list(dict.fromkeys(COLUMNS + column_list))]
    if avg_column is not None:
        df = df.sort_values(by=avg_column, ascending=False)

    # setting column config
    # NOTE(review): the renderer below writes params.value and the reference
    # link through innerHTML without escaping -- confirm both always come
    # from trusted result files.
    model_name_column = {
        'headerName': 'Model Name',
        'field': 'model_name',
        'pinned': 'left',
        'sortable': False,
        'headerStyle': HEADER_STYLE,
        'cellStyle': CELL_STYLE,
        'cellRenderer': JsCode("""class CustomHTML {
            init(params) {
                const link = params.data.reference;
                this.eGui = document.createElement('div');
                this.eGui.innerHTML = link ?
                    `<a href="${link}" target="_blank">${params.value}</a>` :
                    params.value;
            }
            getGui() {
                return this.eGui;
            }
        }"""),
    }

    # Headline score columns. Each one is emitted only if it exists, and is
    # excluded from the trailing "remaining" list so it is never shown twice.
    headline_columns = []
    if avg_column is not None:
        headline_columns.append(_styled_column(avg_column))
    if open_avg_col is not None and open_avg_col != avg_column:
        headline_columns.append(_styled_column(open_avg_col))
    headline_columns.append(_styled_column('Closed average'))

    metadata_columns = [
        _styled_column('embd_dtype', 'Data Type'),
        _styled_column('embd_dim', 'Embd Dim'),
        _styled_column('num_params', 'Model Size (# of Parameters)', cellDataType='number'),
        _styled_column('max_tokens', 'Context Length'),
        _styled_column('query_instruct', 'Query Instruction'),
        _styled_column('corpus_instruct', 'Corpus Instruction'),
    ]

    already_defined = {avg_column, open_avg_col, 'Closed average'}
    remaining_columns = [
        _styled_column(column)
        for column in dict.fromkeys(column_list)
        if column not in already_defined
    ]

    grid_options = {
        'columnDefs': [
            model_name_column,
            *headline_columns,
            *metadata_columns,
            *remaining_columns,
        ],
        'defaultColDef': {
            'filter': True,
            'sortable': True,
            'resizable': True
        },
        'autoSizeStrategy': {
            'type': 'fitCellContents'
        }
    }

    AgGrid(
        df,
        enable_enterprise_modules=False,
        gridOptions=grid_options,
        allow_unsafe_jscode=True,
        columns_auto_size_mode=ColumnsAutoSizeMode.FIT_CONTENTS,
        theme="streamlit",
    )