from st_aggrid import GridOptionsBuilder, AgGrid
from streamlit_searchbox import st_searchbox
import streamlit as st

from .load_data import load_dataframe, sort_by, show_dataframe_top, search_by_name, validate_categories
from .plot import plot_radar_chart_name, plot_radar_chart_rows

# Score columns of the leaderboard; they are scaled by 100 before being shown
# (presumably fractions -> percentages; confirm against load_dataframe).
_SCORE_COLUMNS = ["ARC", "GSM8K", "TruthfulQA", "Winogrande", "HellaSwag", "MMLU", "Average"]

# Metrics listed, in this order, in the per-model result summaries.
_DETAIL_METRICS = ["Average", "ARC", "GSM8K", "TruthfulQA", "Winogrande", "HellaSwag", "MMLU"]

# Sidebar reminder of the metric names accepted by the ordering text input.
_METRICS_REMINDER = """
As a reminder, here are the different metrics:
* ARC
* GSM8K
* TruthfulQA
* Winogrande
* HellaSwag
* MMLU
"""

# Sidebar note explaining when the custom metric ordering is ignored.
_ORDERING_NOTE = """
If there are **typos** in the name of the metrics, or the number of metrics is **different from six**, there will be no effect on the chart and the default ordering will be used.
"""


def _render_explanations():
    """Render the page title and the usage notes shown above the leaderboard."""
    st.markdown("# Open LLM Leaderboard Viz")
    st.markdown("## Some explanations")
    st.markdown("This is a visualization of the results in [open-llm-leaderboard/results](https://huggingface.co/datasets/open-llm-leaderboard/results)")
    st.markdown("To select a model, click on the checkbox beside its name, or search it by its name in the search boxes **Model 1, Model 2, or Model 3** below.")
    st.markdown("You can select up to three models using the search boxes and/or the checkboxes.")
    st.markdown("""In the case you use both the search boxes and the checkboxes, the search boxes will take precedence over the checkboxes, i.e. the models searched using the search boxes will be prioritized over the ones selected using the checkboxes. Please, search models using the search boxes first, and then use the checkboxes. 
""")
    st.markdown("This app displays the top 100 models by default, but you can change that using the number input in the sidebar.")
    st.markdown("By default as well, the maximum number of rows you can display is 500, it is due to the problem with st_aggrid component loading.")
    st.markdown("If your model doesn't show up, please search it by its name.")


def _lookup_model(dataframe, model_name):
    """Return the record of ``model_name`` with its scores multiplied by 100.

    Returns ``None`` when no row of ``dataframe`` has that ``model_name``.
    """
    # .copy() so the scaling below never writes into (or warns about) a slice of ``dataframe``.
    row = dataframe.loc[dataframe["model_name"] == model_name].copy()
    if row.empty:
        return None
    # Cast to float like the leaderboard display path does before scaling.
    row[_SCORE_COLUMNS] = row[_SCORE_COLUMNS].astype(float) * 100
    return row.to_dict("records")[0]


def _selected_records(selected_rows):
    """Normalize AgGrid's ``selected_rows`` into a (possibly empty) list of dict records."""
    if selected_rows is None:
        return []
    # NOTE(review): depending on the st_aggrid version this is either a list of
    # dicts or a DataFrame; extending a list with a DataFrame would add its
    # column names, so convert DataFrames to records explicitly.
    if hasattr(selected_rows, "to_dict"):
        return selected_rows.to_dict("records")
    return list(selected_rows)


def _plot_rows(rows, ordering_metrics, valid_categories):
    """Build the radar chart of ``rows``, using the custom metric order only if it was validated."""
    if valid_categories:
        return plot_radar_chart_rows(rows=rows, categories=ordering_metrics)
    return plot_radar_chart_rows(rows=rows)


def _render_model_details(model):
    """Render one model's name (linked to its Hugging Face page) and its scores rounded to 2 decimals.

    ``model`` is a record dict holding ``model_name`` and every metric of ``_DETAIL_METRICS``.
    """
    name = model["model_name"]
    st.markdown("**Model name:** [%s](https://huggingface.co/%s)" % (name, name))
    st.markdown("**Results:**")
    st.markdown("\n".join("* %s: %s" % (metric, round(model[metric], 2)) for metric in _DETAIL_METRICS))


def display_app():
    """Render the Open LLM Leaderboard visualization page.

    Shows the usage notes, a sortable AgGrid leaderboard, three model search
    boxes, a radar chart of the chosen models (searched models first, then
    checked grid rows, at most three) and a score summary for each of them.
    When nothing is chosen, the first row of the sorted leaderboard is plotted.
    """
    _render_explanations()

    dataframe = load_dataframe()

    st.markdown("## Leaderboard")
    sort_selection = st.selectbox(label="Sort by:", options=list(dataframe.columns), index=7)
    number_of_row = st.sidebar.number_input("Number of top rows to display", min_value=100, max_value=500, value="min", step=100)

    # Model names sort alphabetically; every other column sorts descending.
    if sort_selection is None:
        sort_selection = "model_name"
    ascending = sort_selection == "model_name"

    def search_model(model_name: str):
        """Return the model names containing ``model_name`` (literal, case-sensitive match)."""
        if model_name is None:
            return []
        # regex=False: treat the user's text literally, so characters such as
        # "(" do not raise re.error and "." does not act as a wildcard.
        matches = dataframe["model_name"].str.contains(model_name, regex=False, na=False)
        return dataframe["model_name"][matches].tolist()

    # Sidebar configurations
    selection_mode = st.sidebar.radio(label="Selection mode for the rows", options=["single", "multiple"], index=1)
    st.sidebar.write("In multiple mode, you can select up to three models. If you select more than three models, only the first three will be displayed and plotted.")
    ordering_metrics = st.sidebar.text_input(label="Order of the metrics on the circle, counter-clock wise, beginning at 3 o'clock.", placeholder="ARC, GSM8K, TruthfulQA, Winogrande, HellaSwag, MMLU")
    ordering_metrics = ordering_metrics.replace(" ", "").split(",")
    st.sidebar.markdown(_METRICS_REMINDER)
    st.sidebar.markdown(_ORDERING_NOTE)
    valid_categories = validate_categories(ordering_metrics)

    dataframe = sort_by(dataframe=dataframe, column_name=sort_selection, ascending=ascending)
    # .copy() on the result so the column assignment below never targets a slice.
    dataframe_display = show_dataframe_top(number_of_row, dataframe.copy()).copy()
    dataframe_display[_SCORE_COLUMNS] = (dataframe_display[_SCORE_COLUMNS].astype(float) * 100).round(2)

    # Infer basic colDefs from dataframe types
    grid_builder = GridOptionsBuilder.from_dataframe(dataframe_display)
    grid_builder.configure_selection(selection_mode=selection_mode, use_checkbox=True)
    grid_builder.configure_grid_options(domLayout='normal')
    grid_options = grid_builder.build()

    column1, _spacer, column2 = st.columns([0.26, 0.05, 0.69], gap="small")

    with column1:
        grid_response = AgGrid(
            dataframe_display,
            gridOptions=grid_options,
            height=300,
            width='40%'
        )
        model_one = st_searchbox(label="Model 1", search_function=search_model, key="model_1", default=None)
        model_two = st_searchbox(label="Model 2", search_function=search_model, key="model_2", default=None)
        model_three = st_searchbox(label="Model 3", search_function=search_model, key="model_3", default=None)

    # Searched models come first so they take precedence over checked rows.
    model_list = []
    for searched_name in (model_one, model_two, model_three):
        if searched_name is not None:
            record = _lookup_model(dataframe, searched_name)
            if record is not None:
                model_list.append(record)

    # Fallback model plotted when nothing is searched or checked.
    subdata = dataframe.head(1)
    model_name = subdata["model_name"].values[0] if len(subdata) > 0 else ""

    selected_rows = _selected_records(grid_response["selected_rows"])

    with column2:
        if selected_rows:
            # Checked rows are appended after the searched models; keep at most three.
            model_list = (model_list + selected_rows)[:3]
        if model_list:
            model_list = sorted(model_list, key=lambda record: record["Average"], reverse=True)
            st.plotly_chart(_plot_rows(model_list, ordering_metrics, valid_categories), use_container_width=False)
        elif len(subdata) > 0:
            if valid_categories:
                figure = plot_radar_chart_name(dataframe=subdata, categories=ordering_metrics, model_name=model_name)
            else:
                figure = plot_radar_chart_name(dataframe=subdata, model_name=model_name)
            st.plotly_chart(figure, use_container_width=True)

    if len(model_list) > 1:
        st.markdown("## Models")
        shown_models = model_list[:3]
        for detail_column, model in zip(st.columns(len(shown_models)), shown_models):
            with detail_column:
                _render_model_details(model)
    elif len(model_list) == 1:
        # Previously indexed with the loop variable ``i`` from the branch above,
        # which is unbound here (UnboundLocalError).
        _render_model_details(model_list[0])
        st.markdown("For more details, hover over the radar chart.")
    else:
        st.markdown("**Model name:** %s" % model_name)
        st.markdown("For more details, select the model.")