|
|
|
from st_aggrid import GridOptionsBuilder, AgGrid |
|
from streamlit_searchbox import st_searchbox |
|
import streamlit as st |
|
from .load_data import load_dataframe, sort_by, show_dataframe_top, search_by_name, validate_categories |
|
from .plot import plot_radar_chart_name, plot_radar_chart_rows |
|
|
|
|
|
# Score columns that are stored as fractions and shown as percentages.
# A single list is used on both sides of every multi-column assignment
# because pandas pairs the columns of a DataFrame right-hand side by position.
_SCORE_COLUMNS = ["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K", "MMLU", "Average"]

# Order in which the scores are listed in the per-model summary.
_RESULT_ORDER = ["Average", "ARC", "GSM8K", "TruthfulQA", "Winogrande", "HellaSwag", "MMLU"]

# Maximum number of models plotted and summarised at once.
_MAX_MODELS = 3


def _selected_rows_as_records(selected_rows):
    """Return the grid selection as a list of row dicts.

    Depending on the installed st_aggrid release, the selection is either
    ``None``, a list of dicts, or a pandas DataFrame. Iterating a DataFrame
    yields column labels, so it has to be converted to records explicitly.
    """
    if selected_rows is None:
        return []
    if hasattr(selected_rows, "to_dict"):
        return selected_rows.to_dict("records")
    return list(selected_rows)


def _render_model_results(model):
    """Write the model link and its rounded scores in the current container.

    Args:
        model: mapping holding ``"model_name"`` and one numeric value per
            entry of ``_RESULT_ORDER``.
    """
    st.markdown("**Model name:** [%s](https://huggingface.co/%s)" % (model["model_name"], model["model_name"]))
    st.markdown("**Results:**")
    st.markdown("\n".join("* %s: %s" % (metric, round(model[metric], 2)) for metric in _RESULT_ORDER))


def display_app():
    """Render the leaderboard page: help text, results grid, radar chart and summary.

    Models can be picked with the three search boxes and/or the grid
    checkboxes. Searched models are placed first, so they win when more than
    three models are selected in total.
    """
    st.markdown("# Open LLM Leaderboard Viz")
    st.markdown("## Some explanations")
    st.markdown("This is a visualization of the results in [open-llm-leaderboard/results](https://huggingface.co/datasets/open-llm-leaderboard/results)")
    st.markdown("To select a model, click on the checkbox beside its name, or search it by its name in the search boxes **Model 1, Model 2, or Model 3** bellow.")
    st.markdown("You can select up to three models using the search boxes and/or the checkboxes.")
    st.markdown("""In the case you use both the search boxes and the checkboxes, the search boxes will take precedence over the checkboxes,
                i.e. the models searched using the search boxes will be prioritized over the ones selected using the checkboxes.
                Please, search models using the search boxes first, and then use the checkboxes.
                """)
    st.markdown("This app displays the top 100 models by default, but you can change that using the number input in the sidebar.")
    st.markdown("By default as well, the maximum number of row you can display is 500, it is due to the problem with st_aggrid component loading.")
    st.markdown("If your model doesn't show up, please search it by its name.")

    dataframe = load_dataframe()

    st.markdown("## Leaderboard")
    sortable_columns = list(dataframe.columns)
    # The default sort column is the eighth one (presumably "Average" -- TODO
    # confirm against load_dataframe); clamp so a narrower frame cannot make
    # the select box raise.
    default_sort_index = min(7, max(len(sortable_columns) - 1, 0))
    sort_selection = st.selectbox(label="Sort by:", options=sortable_columns, index=default_sort_index)
    number_of_row = st.sidebar.number_input("Number of top rows to display", min_value=100, max_value=500, value="min", step=100)

    if sort_selection is None:
        sort_selection = "model_name"
    # Names sort alphabetically; every other column sorts best-first.
    ascending = sort_selection == "model_name"

    def search_model(model_name: str):
        """Return the model names containing ``model_name`` as a plain list."""
        if model_name is None:
            return []
        names = dataframe["model_name"]
        # Literal matching: names contain regex metacharacters such as "+" or
        # "(", which would otherwise raise or match the wrong rows. na=False
        # keeps the mask boolean when a name is missing.
        matches = names.str.contains(model_name, regex=False, na=False)
        return names[matches].tolist()

    model_list = []

    selection_mode = st.sidebar.radio(label="Selection mode for the rows", options=["single", "multiple"], index=1)
    st.sidebar.write("In multiple mode, you can select up to three models. If you select more than three models, only the first three will be displayed and plotted.")
    ordering_metrics = st.sidebar.text_input(
        label="Order of the metrics on the circle, counter-clock wise, beginning at 3 o'clock.",
        placeholder="ARC, GSM8K, TruthfulQA, Winogrande, HellaSwag, MMLU",
    )
    ordering_metrics = ordering_metrics.replace(" ", "").split(",")

    st.sidebar.markdown("""
                As a reminder, here are the different metrics:
                * ARC
                * GSM8K
                * TruthfulQA
                * Winogrande
                * HellaSwag
                * MMLU
                """)
    st.sidebar.markdown("""
                If there are **typos** in the name of the metrics, or the number of metrics
                is **different of six**, there will be no effect on the chart and the
                default ordering will be used.
                """)

    valid_categories = validate_categories(ordering_metrics)
    # The custom ordering is forwarded to the plot helpers only when it
    # validated; otherwise they fall back to their own default.
    chart_kwargs = {"categories": ordering_metrics} if valid_categories else {}

    dataframe = sort_by(dataframe=dataframe, column_name=sort_selection, ascending=ascending)
    dataframe_display = show_dataframe_top(number_of_row, dataframe.copy())
    # Scores are shown as percentages rounded to two decimals. The assignment
    # aligns on the index, so only the rows kept for display receive values.
    dataframe_display[_SCORE_COLUMNS] = (dataframe[_SCORE_COLUMNS].astype(float) * 100).round(2)

    gb = GridOptionsBuilder.from_dataframe(dataframe_display)
    gb.configure_selection(selection_mode=selection_mode, use_checkbox=True)
    gb.configure_grid_options(domLayout='normal')
    gridOptions = gb.build()

    # The narrow middle column only acts as a spacer between grid and chart.
    column1, _spacer, column2 = st.columns([0.26, 0.05, 0.69], gap="small")

    # NOTE(review): the reviewed source had lost its indentation; the search
    # boxes are assumed to live in the left column, under the grid -- confirm.
    with column1:
        grid_response = AgGrid(
            dataframe_display,
            gridOptions=gridOptions,
            height=300,
            width='40%'
        )
        searched_names = [
            st_searchbox(label="Model %d" % slot, search_function=search_model, key="model_%d" % slot, default=None)
            for slot in (1, 2, 3)
        ]

    for searched_name in searched_names:
        if searched_name is None:
            continue
        # Work on a copy so scaling does not write into a slice of `dataframe`.
        row = dataframe[dataframe["model_name"] == searched_name].copy()
        if row.empty:
            # A stale search-box value may no longer exist in the data.
            continue
        # Same percentage scale as the rows coming from the grid selection.
        row[_SCORE_COLUMNS] = row[_SCORE_COLUMNS].astype(float) * 100
        model_list.append(row.to_dict("records")[0])

    # Fallback when nothing is selected: the first row of the sorted frame.
    subdata = dataframe.head(1)
    model_name = subdata["model_name"].values[0] if len(subdata) > 0 else ""

    with column2:
        # Searched models come first, so they survive the cut to three.
        model_list = (model_list + _selected_rows_as_records(grid_response['selected_rows']))[:_MAX_MODELS]
        if model_list:
            model_list = sorted(model_list, key=lambda x: x["Average"], reverse=True)
            figure = plot_radar_chart_rows(rows=model_list, **chart_kwargs)
            st.plotly_chart(figure, use_container_width=False)
        elif len(subdata) > 0:
            figure = plot_radar_chart_name(dataframe=subdata, model_name=model_name, **chart_kwargs)
            st.plotly_chart(figure, use_container_width=True)

    # NOTE(review): this summary is assumed to sit below both columns at page
    # level; move it under `with column2:` if the original nested it there.
    if len(model_list) > 1:
        st.markdown("## Models")
        n_col = min(len(model_list), _MAX_MODELS)
        for column, model in zip(st.columns(n_col), model_list):
            with column:
                _render_model_results(model)
    elif len(model_list) == 1:
        _render_model_results(model_list[0])
        st.markdown("For more details, hover over the radar chart.")
    else:
        st.markdown("**Model name:** %s" % model_name)
        st.markdown("For more details, select the model.")