Spaces:
Running
Running
File size: 7,197 Bytes
6d0d847 e9e9e4c 6d0d847 08d39e4 6d0d847 08d39e4 e9e9e4c 6d0d847 508092c 08d39e4 5dd4497 f8a58e0 5dd4497 6d0d847 08d39e4 e9e9e4c 08d39e4 e9e9e4c 6d0d847 08d39e4 75ecf86 08d39e4 75ecf86 08d39e4 75ecf86 08d39e4 0af2238 08d39e4 0af2238 6d0d847 0af2238 863c2e6 08d39e4 e9e9e4c 8f1beb0 6d0d847 e9e9e4c 863c2e6 e9e9e4c 1b7db5c e9e9e4c 08d39e4 e9e9e4c 6d0d847 e9e9e4c 6d0d847 e9e9e4c 6d0d847 1b7db5c 1fa21a0 1b7db5c a12f124 e9e9e4c a12f124 e9e9e4c 6d0d847 e9e9e4c 6d0d847 08d39e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 |
import streamlit as st
import pandas as pd
import numpy as np
from streamlit_echarts import st_echarts
from streamlit.components.v1 import html
import pandas as pd
from model_information import get_dataframe
# Model metadata table, loaded once when this module is imported. The code
# below reads its 'Original Name', 'Proper Display Name', 'Model Size' and
# 'Link' columns; get_dataframe comes from the local model_information module.
info_df = get_dataframe()
def _inject_multiselect_css():
    """Inject CSS restyling the ``span`` elements inside multiselect widgets.

    Per the inline CSS comments, this colours the selected-item tags.
    """
    st.markdown("""
<style>
.stMultiSelect [data-baseweb=select] span {
    max-width: 800px;
    font-size: 0.9rem;
    background-color: #3C6478 !important; /* Background color for selected items */
    color: white; /* Change text color */
}
</style>
""", unsafe_allow_html=True)


def _load_chart_data(folder_name, category_one, category_two, model_size_range):
    """Read one results CSV, add display names and apply the model-size filter.

    Returns the CSV's columns that contain no missing values (numbers rounded
    to 3 decimals) plus a ``model_show`` column with each model's display name.
    """
    data_path = f"./results/{folder_name}/{category_one}/{category_two}.csv"
    # Any column with a missing value is dropped, so the remaining score
    # columns are complete for every model.
    chart_data = pd.read_csv(data_path).dropna(axis='columns').round(3)

    # Raw model name -> display name; models absent from the metadata fall
    # back to their raw name with '_' replaced by '-'.
    display_model_names = {
        key.strip(): val.strip()
        for key, val in zip(info_df['Original Name'], info_df['Proper Display Name'])
    }
    chart_data['model_show'] = chart_data['Model'].map(display_model_names).fillna(
        chart_data['Model'].apply(lambda x: x.replace('_', '-'))
    )

    if model_size_range == 'All':
        return chart_data

    # Sizes are compared numerically (presumably billions of parameters, given
    # the '<10B' labels). Models with no recorded size get the sentinel 99999,
    # so they only survive the 'All' and '>30B' filters.
    model2sizes = {
        key.strip(): val.strip()
        for key, val in zip(info_df['Original Name'], info_df['Model Size'])
    }
    sizes = chart_data['Model'].map(model2sizes).fillna('99999').astype(float)
    if model_size_range == '<10B':
        chart_data = chart_data[sizes < 10]
    elif model_size_range == '10B-30B':
        chart_data = chart_data[(sizes >= 10) & (sizes < 30)]
    elif model_size_range == '>30B':
        chart_data = chart_data[sizes >= 30]
    # Any other value leaves the data unfiltered.
    return chart_data


def _padded_axis_range(lowest, highest):
    """Return (min, max) y-axis bounds ~10% outside the data, rounded to 0.1."""
    # abs() keeps the padding pointing outward for negative values too.
    min_value = round(lowest - 0.1 * abs(lowest), 1)
    max_value = round(highest + 0.1 * abs(highest), 1)
    # Rounding can land inside the data (e.g. 0.31 -> 0.341 -> 0.3), which
    # would clip bars; widen by one 0.1 step when that happens.
    if min_value > lowest:
        min_value = round(min_value - 0.1, 1)
    if max_value < highest:
        max_value = round(max_value + 0.1, 1)
    return min_value, max_value


def _render_table(chart_data, data_columns):
    """Render models with their links and scores, ranked by the first score column."""
    with st.container():
        st.markdown('##### TABLE')
        model_link = {
            key.strip(): val
            for key, val in zip(info_df['Proper Display Name'], info_df['Link'])
        }
        # assign() keeps the link column out of chart_data itself; .copy()
        # makes the column write below target an independent frame.
        table = chart_data.assign(
            model_link=chart_data['model_show'].map(model_link)
        )[['model_show', 'model_link'] + data_columns].copy()

        score_col = data_columns[0]
        # Numeric cells become 3-decimal floats (ints are converted to float);
        # any other value passes through unchanged.
        table[score_col] = table[score_col].apply(
            lambda x: round(float(x), 3) if isinstance(x, (int, float)) else x
        )
        table = table.sort_values(by=score_col, ascending=False).reset_index(drop=True)
        styled_df = table.style.highlight_max(subset=[score_col], color='yellow')
        st.dataframe(
            styled_df,
            column_config={
                'model_show': 'Model',
                'model_link': st.column_config.LinkColumn("Model Link"),
            },
            hide_index=True,
            use_container_width=True,
        )


def _render_chart(chart_data, data_columns, sort, num_sort):
    """Render a "Show Chart" toggle and, while on, a bar chart of every score column."""
    # The toggle flag is kept in st.session_state so it persists across reruns.
    if "show_chart" not in st.session_state:
        st.session_state.show_chart = False
    if st.button("Show Chart"):
        st.session_state.show_chart = not st.session_state.show_chart
    if not st.session_state.show_chart:
        return

    with st.container():
        st.markdown('##### CHART')
        # Assumes every score column is numeric (they are all plotted as bars).
        min_value, max_value = _padded_axis_range(
            float(chart_data[data_columns].min().min()),
            float(chart_data[data_columns].max().max()),
        )
        # Only rows missing a score are dropped; models without a link stay.
        chart_data = chart_data.sort_values(
            by=[sort], ascending=(num_sort == 'Ascending')
        ).dropna(subset=data_columns)

        options = {
            "tooltip": {
                "trigger": "axis",
                "axisPointer": {"type": "cross", "label": {"backgroundColor": "#6a7985"}},
                "triggerOn": 'mousemove',
            },
            "legend": {"data": data_columns},
            "toolbox": {"feature": {"saveAsImage": {}}},
            "grid": {"left": "3%", "right": "4%", "bottom": "3%", "containLabel": True},
            "xAxis": [{
                "type": "category",
                "boundaryGap": True,
                "triggerEvent": True,
                "data": chart_data['model_show'].tolist(),
            }],
            "yAxis": [{
                "type": "value",
                "min": min_value,
                "max": max_value,
                "boundaryGap": True,
            }],
            "series": [
                {"name": str(col), "type": "bar", "data": chart_data[col].tolist()}
                for col in data_columns
            ],
        }
        events = {"click": "function(params) { return params.value }"}
        st_echarts(options=options, events=events, height="500px")


def draw(folder_name, category_one, category_two, sort, num_sort, model_size_range):
    """Render the results table and an optional bar chart for one results CSV.

    Reads ``./results/{folder_name}/{category_one}/{category_two}.csv`` and lets
    the user pick which models to show. It then renders a table sorted by the
    first score column and, behind a "Show Chart" toggle, a grouped bar chart
    of every score column.

    Args:
        folder_name: Directory under ``./results``.
        category_one: Sub-directory of ``folder_name``.
        category_two: CSV file name without the ``.csv`` extension.
        sort: Column name the chart's bars are ordered by.
        num_sort: ``'Ascending'`` orders bars ascending; any other value descending.
        model_size_range: ``'All'``, ``'<10B'``, ``'10B-30B'`` or ``'>30B'``;
            any other value applies no size filter.
    """
    chart_data = _load_chart_data(folder_name, category_one, category_two, model_size_range)
    _inject_multiselect_css()

    model_names = sorted(chart_data['model_show'].tolist())
    models = st.multiselect("Please choose the model", model_names, default=model_names)
    chart_data = chart_data[chart_data['model_show'].isin(models)]
    if chart_data.empty:
        return

    # Every column except the two model-name columns is a score series.
    data_columns = [c for c in chart_data.columns if c not in ('Model', 'model_show')]
    if not data_columns:
        return

    _render_table(chart_data, data_columns)
    _render_chart(chart_data, data_columns, sort, num_sort)