File size: 11,124 Bytes
ff0c7de
fbcd930
a7f9f33
1d040cb
ff0c7de
963c6da
1d040cb
fbcd930
 
 
a7f9f33
fbcd930
a7f9f33
 
 
 
 
 
 
adbb181
12f938b
fbcd930
 
a7f9f33
fbcd930
a7f9f33
864cb6d
 
 
adbb181
fbcd930
12f938b
fbcd930
 
 
 
 
 
 
12f938b
a7f9f33
 
 
eeeb78f
a7f9f33
 
 
eeeb78f
a7f9f33
 
 
ff0c7de
 
a7f9f33
adbb181
ff0c7de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fbcd930
864cb6d
 
a7f9f33
 
fbcd930
 
 
 
 
 
adbb181
fbcd930
 
 
 
 
 
 
 
 
 
 
 
a7f9f33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fbcd930
 
 
 
 
 
 
 
ff0c7de
eeeb78f
 
 
a7f9f33
 
 
ff0c7de
a7f9f33
ff0c7de
a7f9f33
adbb181
ff0c7de
a7f9f33
 
 
 
 
 
 
 
 
fbcd930
 
ff0c7de
 
 
 
 
 
fbcd930
 
a7f9f33
 
1f586be
 
 
 
a7f9f33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
864cb6d
a7f9f33
59592ea
a7f9f33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
864cb6d
a7f9f33
fbcd930
a7f9f33
eeeb78f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207

from st_aggrid import GridOptionsBuilder, AgGrid
from streamlit_searchbox import st_searchbox
import streamlit as st
from .load_data import load_dataframe, sort_by, show_dataframe_top, search_by_name, validate_categories
from .plot import plot_radar_chart_name, plot_radar_chart_rows


def _find_model_names(dataframe, query):
    """Return the model names in ``dataframe`` that contain ``query``.

    Args:
        dataframe: Frame with a ``model_name`` column.
        query: Text typed in a search box; ``None`` or ``""`` means "no filter".

    Returns:
        A list of matching model names (all names when ``query`` is empty).
    """
    names = dataframe["model_name"]
    if not query:
        return names.tolist()
    # Match literally: search terms such as "c++" or "(" are not valid regular
    # expressions and would otherwise raise. ``na=False`` drops missing names.
    matches = names.str.contains(query, regex=False, na=False)
    return names[matches].tolist()


def _scaled_model_record(dataframe, model_name, score_columns):
    """Return the first row named ``model_name`` as a dict, scores times 100.

    Args:
        dataframe: Frame with a ``model_name`` column and the score columns.
        model_name: Exact model name to look up.
        score_columns: Names of the columns to multiply by 100.

    Returns:
        The record as a dict, or ``None`` when no row carries that name (for
        instance when the row was removed by the dtype filter).
    """
    matches = dataframe[dataframe["model_name"] == model_name]
    if matches.empty:
        return None
    # Work on a copy so the assignment never targets a slice of ``dataframe``.
    first_match = matches.head(1).copy()
    # Same float conversion as the one applied to the displayed table.
    first_match[score_columns] = first_match[score_columns].astype(float) * 100
    return first_match.to_dict("records")[0]


def _selected_rows_as_records(selected_rows):
    """Normalize the grid's ``selected_rows`` value to a list of dicts.

    Accepts ``None``, an object exposing ``to_dict`` (a DataFrame), or an
    iterable of records.
    """
    if selected_rows is None:
        return []
    if hasattr(selected_rows, "to_dict"):
        return selected_rows.to_dict("records")
    return list(selected_rows)


def _render_model_details(model):
    """Write the name, the rounded scores and the dtype of one model record."""
    metrics = ("Average", "ARC", "GSM8K", "TruthfulQA", "Winogrande", "HellaSwag", "MMLU")
    st.markdown("**Model name:**   [%s](https://huggingface.co/%s)" % (model["model_name"], model["model_name"]))
    st.markdown("**Results:**")
    st.markdown(""" 
                                * Average:    %s  
                                * ARC:        %s
                                * GSM8K:      %s
                                * TruthfulQA: %s
                                * Winogrande: %s
                                * HellaSwag:  %s
                                * MMLU:       %s
                            """ % tuple(round(model[metric], 2) for metric in metrics))
    st.markdown("**dtype:** %s" % model["model_dtype"])


def display_app():
    """Render the leaderboard page.

    Loads the results with ``load_dataframe``, shows the sort/filter widgets
    and the sidebar options, displays the top rows in an AgGrid table, and
    plots a radar chart for up to three models chosen through the search
    boxes and/or the grid checkboxes (search-box models come first). When
    nothing is selected, the first row of the sorted, filtered data is
    plotted. A textual summary of the plotted models is written below.
    """
    st.markdown("# Open LLM Leaderboard Viz")
    st.markdown("## Some explanations")
    st.markdown("This is a visualization of the results in [open-llm-leaderboard/results](https://huggingface.co/datasets/open-llm-leaderboard/results)")
    st.markdown("To select a model, click on the checkbox beside its name, or search it by its name in the search boxes **Model 1, Model 2, or Model 3** bellow.")
    st.markdown("You can select up to three models using the search boxes and/or the checkboxes.")
    st.markdown("""In the case you use both the search boxes and the checkboxes, the search boxes will take precedence over the checkboxes, 
                i.e. the models searched using the search boxes will be prioritized over the ones selected using the checkboxes.
                   Please, search models using the search boxes first, and then use the checkboxes. 
                """)
    st.markdown("This app displays the top 100 models by default, but you can change that using the number input in the sidebar.")
    st.markdown("By default as well, the maximum number of row you can display is 500, it is due to the problem with st_aggrid component loading.")
    st.markdown("If your model doesn't show up, please search it by its name.")

    dataframe = load_dataframe()
    categories_display = ["ARC", "GSM8K", "TruthfulQA", "Winogrande", "HellaSwag", "MMLU", "Average"]

    st.markdown("## Leaderboard")
    sort_selection = st.selectbox(
        label="Sort by:",
        options=list(dataframe.columns.difference(["model_dtype"])),
        index=1,
    )
    d_type_options = ["all", "torch.bfloat16", "torch.float16", "4bit", "8bit"]
    d_type = st.radio(label="Filter by dtype", options=d_type_options, index=0, horizontal=True)
    number_of_row = st.sidebar.number_input("Number of top rows to display", min_value=100, max_value=500, value="min", step=100)

    # Names sort ascending; every other column sorts descending.
    if sort_selection is None:
        sort_selection = "model_name"
    ascending = sort_selection == "model_name"

    # Sidebar configurations
    selection_mode = st.sidebar.radio(label="Selection mode for the rows", options=["single", "multiple"], index=1)
    st.sidebar.write("In multiple mode, you can select up to three models. If you select more than three models, only the first three will be displayed and plotted.")
    ordering_metrics = st.sidebar.text_input(label="Order of the metrics on the circle, counter-clock wise, beginning at 3 o'clock.",
                                             placeholder="ARC, GSM8K, TruthfulQA, Winogrande, HellaSwag, MMLU")
    ordering_metrics = ordering_metrics.replace(" ", "").split(",")

    st.sidebar.markdown("""
                        As a reminder, here are the different metrics:
                        * ARC
                        * GSM8K
                        * TruthfulQA
                        * Winogrande
                        * HellaSwag
                        * MMLU
                        """)
    st.sidebar.markdown("""
                        If there are **typos** in the name of the metrics, or the number of metrics 
                        is **different of six**, there will be no effect on the chart and the 
                        default ordering will be used.
                         """)

    valid_categories = validate_categories(ordering_metrics)
    # The custom metric order is forwarded to the plots only when it validated.
    plot_kwargs = {"categories": ordering_metrics} if valid_categories else {}

    dataframe = sort_by(dataframe=dataframe, column_name=sort_selection, ascending=ascending)
    if d_type != "all":
        dataframe = dataframe[dataframe["model_dtype"] == d_type]

    # The grid shows the top rows with scores scaled by 100 and rounded to two
    # decimals; the assignment aligns on the row index.
    dataframe_display = show_dataframe_top(number_of_row, dataframe.copy())
    dataframe_display[categories_display] = (dataframe[categories_display].astype(float) * 100).round(2)

    # Infer basic colDefs from dataframe types
    gb = GridOptionsBuilder.from_dataframe(dataframe_display)
    gb.configure_selection(selection_mode=selection_mode, use_checkbox=True)
    gb.configure_grid_options(domLayout='normal')
    gridOptions = gb.build()

    column1, col3, column2 = st.columns([0.26, 0.05, 0.69], gap="small")

    with column1:
        grid_response = AgGrid(
            dataframe_display,
            gridOptions=gridOptions,
            height=300,
            width='40%',
        )

    # Dynamic search boxes: they search the sorted, dtype-filtered data.
    def search_model(model_name: str):
        return _find_model_names(dataframe, model_name)

    searched_names = [
        st_searchbox(label=f"Model {position}", search_function=search_model, key=f"model_{position}", default=None)
        for position in (1, 2, 3)
    ]

    # Search-box models go first so they win over checkbox selections when
    # the list is cut down to three entries.
    model_list = []
    for searched_name in searched_names:
        if searched_name is None:
            continue
        record = _scaled_model_record(dataframe, searched_name, categories_display)
        # A search box may still hold a model that the current dtype filter
        # removed; skip it instead of failing on an empty lookup.
        if record is not None:
            model_list.append(record)

    model_list.extend(_selected_rows_as_records(grid_response['selected_rows']))
    model_list = sorted(model_list[:3], key=lambda model: model["Average"], reverse=True)

    # Fallback when nothing is selected: the first row of the current data.
    subdata = dataframe.head(1)
    model_name = subdata["model_name"].values[0] if len(subdata) > 0 else ""

    with column2:
        if model_list:
            figure = plot_radar_chart_rows(rows=model_list, **plot_kwargs)
            st.plotly_chart(figure, use_container_width=False)
        elif len(subdata) > 0:
            figure = plot_radar_chart_name(dataframe=subdata, model_name=model_name, **plot_kwargs)
            st.plotly_chart(figure, use_container_width=True)

    if len(model_list) > 1:
        st.markdown("## Models")
        # model_list holds at most three entries, one column per model.
        for column, model in zip(st.columns(len(model_list)), model_list):
            with column:
                _render_model_details(model)
    elif len(model_list) == 1:
        _render_model_details(model_list[0])
        st.markdown("For more details, hover over the radar chart.")
    else:
        st.markdown("**Model name:**   %s" % model_name)
        st.markdown("For more details, select the first model in the list/leaderboard.")