import streamlit as st
from src.load_data import load_dataframe, sort_by
from src.plot import plot_radar_chart_index, plot_radar_chart_name
from st_aggrid import GridOptionsBuilder, AgGrid

# Benchmark score columns. The grid multiplies them by 100 and rounds them,
# presumably turning fractional scores into percentages (confirm against
# load_dataframe).
_SCORE_COLUMNS = ["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K", "MMLU", "Average"]


def _filter_by_name(dataframe, name):
    """Return the rows whose ``model_name`` contains *name* as a literal substring.

    An empty (or whitespace-only / None) query returns *dataframe* unchanged.
    When nothing matches, an info message is shown and the unfiltered
    *dataframe* is returned so the grid and chart are never left empty.
    """
    query = name.strip() if name else ""
    if not query:
        return dataframe
    # regex=False: user text such as "(" or "+" must not be parsed as a regex.
    # na=False: rows with a missing model_name simply do not match.
    matches = dataframe["model_name"].str.contains(query, regex=False, na=False)
    if not matches.any():
        st.info(f"No model name contains '{query}'; showing all models.")
        return dataframe
    return dataframe[matches]


def _format_scores(dataframe):
    """Return a copy of *dataframe* with the score columns x100, rounded to 2 decimals."""
    display = dataframe.copy()
    display[_SCORE_COLUMNS] = dataframe[_SCORE_COLUMNS].astype(float).mul(100).round(2)
    return display


def _selected_model_name(grid_response):
    """Return the ``model_name`` of the first selected grid row, or None if none is selected."""
    rows = grid_response["selected_rows"]
    if rows is None or len(rows) == 0:
        return None
    # Depending on the st_aggrid version, selected_rows is either a pandas
    # DataFrame or a list of dicts; support both shapes.
    first = rows.iloc[0] if hasattr(rows, "iloc") else rows[0]
    return first["model_name"]


def display_app():
    """Render the Open LLM Leaderboard visualisation page.

    The user can sort and search the leaderboard, which is shown in an AgGrid
    table with single-row checkbox selection. A radar chart is drawn for the
    selected model, or for the first row of the current view when nothing is
    selected.
    """
    st.markdown("# Open LLM Leaderboard Viz")
    st.markdown("This is a visualization of the results in [open-llm-leaderboard/results](https://huggingface.co/datasets/open-llm-leaderboard/results)")
    st.markdown("To select a model, click on the checkbox beside its name.")

    dataframe = load_dataframe()

    sort_selection = st.selectbox(label="Sort by:", options=list(dataframe.columns))
    if sort_selection is None:
        sort_selection = "model_name"
    # Names sort ascending; every other column sorts descending (highest first).
    ascending = sort_selection == "model_name"

    name = st.text_input(label=":mag: Search by name")
    dataframe = _filter_by_name(dataframe, name)
    dataframe = sort_by(dataframe=dataframe, column_name=sort_selection, ascending=ascending)

    dataframe_display = _format_scores(dataframe)

    # Infer basic column definitions from the dataframe's dtypes.
    gb = GridOptionsBuilder.from_dataframe(dataframe_display)
    gb.configure_selection(selection_mode="single", use_checkbox=True)
    gb.configure_grid_options(domLayout="normal")
    grid_options = gb.build()

    # The middle column is an empty spacer between the grid and the chart.
    column1, _spacer, column2 = st.columns([0.26, 0.05, 0.69], gap="small")
    with column1:
        grid_response = AgGrid(
            dataframe_display,
            gridOptions=grid_options,
            height=300,
            width="40%",
        )

    # Default to the first row of the current (filtered, sorted) view.
    subdata = dataframe.head(1)
    model_name = subdata["model_name"].values[0] if len(subdata) > 0 else ""

    selected_name = _selected_model_name(grid_response)
    with column2:
        if selected_name is not None:
            model_name = selected_name
            figure = plot_radar_chart_name(dataframe=dataframe, model_name=model_name)
            # Same width setting as the default chart so the layout does not
            # jump when a row is selected.
            st.plotly_chart(figure, use_container_width=True)
        elif len(subdata) > 0:
            figure = plot_radar_chart_name(dataframe=subdata, model_name=model_name)
            st.plotly_chart(figure, use_container_width=True)
        st.markdown(f"**Model name:** {model_name}")