dimbyTa committed on
Commit
fbcd930
1 Parent(s): 3bcae7b

Adding application

Browse files
app.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from src.display import display_app
2
+
3
+ display_app()
public/datasets/models_scores.csv ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ altair==5.2.0
2
+ attrs==23.2.0
3
+ backports.zoneinfo==0.2.1
4
+ blinker==1.7.0
5
+ cachetools==5.3.3
6
+ certifi==2024.2.2
7
+ charset-normalizer==3.3.2
8
+ click==8.1.7
9
+ contourpy==1.1.1
10
+ cycler==0.12.1
11
+ fonttools==4.49.0
12
+ gitdb==4.0.11
13
+ GitPython==3.1.42
14
+ idna==3.6
15
+ importlib-metadata==7.0.1
16
+ importlib_resources==6.1.2
17
+ Jinja2==3.1.3
18
+ jsonschema==4.21.1
19
+ jsonschema-specifications==2023.12.1
20
+ kiwisolver==1.4.5
21
+ markdown-it-py==3.0.0
22
+ MarkupSafe==2.1.5
23
+ matplotlib==3.7.5
24
+ mdurl==0.1.2
25
+ numpy==1.24.4
26
+ packaging==23.2
27
+ pandas==2.0.3
28
+ pillow==10.2.0
29
+ pkgutil_resolve_name==1.3.10
30
+ plotly==5.19.0
31
+ protobuf==4.25.3
32
+ pyarrow==15.0.0
33
+ pydeck==0.8.1b0
34
+ Pygments==2.17.2
35
+ pyparsing==3.1.1
36
+ python-dateutil==2.9.0.post0
37
+ python-decouple==3.8
38
+ pytz==2024.1
39
+ referencing==0.33.0
40
+ requests==2.31.0
41
+ rich==13.7.1
42
+ rpds-py==0.18.0
43
+ six==1.16.0
44
+ smmap==5.0.1
45
+ streamlit==1.31.1
46
+ streamlit-aggrid==0.3.4.post3
47
+ tenacity==8.2.3
48
+ toml==0.10.2
49
+ toolz==0.12.1
50
+ tornado==6.4
51
+ typing_extensions==4.10.0
52
+ tzdata==2024.1
53
+ tzlocal==5.2
54
+ urllib3==2.2.1
55
+ validators==0.22.0
56
+ watchdog==4.0.0
57
+ zipp==3.17.0
src/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # src/__init__.py
src/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (140 Bytes). View file
 
src/__pycache__/display.cpython-38.pyc ADDED
Binary file (2.25 kB). View file
 
src/__pycache__/load_data.cpython-38.pyc ADDED
Binary file (1.08 kB). View file
 
src/__pycache__/plot.cpython-38.pyc ADDED
Binary file (2.45 kB). View file
 
src/content.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Nothing for now
src/display.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from src.load_data import load_dataframe, sort_by
3
+ from src.plot import plot_radar_chart_index, plot_radar_chart_name
4
+ from st_aggrid import GridOptionsBuilder, AgGrid
5
+
6
# Benchmark score columns; they are multiplied by 100 and rounded for display.
_SCORE_COLUMNS = ["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K", "MMLU", "Average"]


def display_app():
    """
    Render the Open LLM Leaderboard visualization page.

    The page shows, from top to bottom:
    - a title and a short description,
    - a "Sort by" select box and a "Search by name" text input,
    - a selectable grid of models (left column) next to a radar chart of the
      chosen model (right column),
    - the name of the model currently plotted.

    The plotted model is the row ticked in the grid or, when nothing is
    ticked, the first row of the sorted/filtered table.
    """
    st.markdown("# Open LLM Leaderboard Viz")
    st.markdown("This is a visualization of the results in [open-llm-leaderboard/results](https://huggingface.co/datasets/open-llm-leaderboard/results)")
    st.markdown("To select a model, click on the checkbox beside its name.")

    dataframe = load_dataframe()

    sort_selection = st.selectbox(label="Sort by:", options=list(dataframe.columns))
    if sort_selection is None:
        sort_selection = "model_name"
    # Names are sorted A -> Z; every other column is sorted in descending order.
    ascending = sort_selection == "model_name"

    name = st.text_input(label=":mag: Search by name")
    if name:
        # regex=False: the user types plain text, so characters such as "(" or
        # "+" must not be interpreted as a pattern (they would raise re.error).
        # na=False keeps the mask strictly boolean if a name is missing.
        matches = dataframe["model_name"].str.contains(
            name, case=False, regex=False, na=False
        )
        # Only narrow the table when something matched; otherwise keep showing
        # the full leaderboard (a mask's length never tells us that).
        if matches.any():
            dataframe = dataframe[matches]

    dataframe = sort_by(dataframe=dataframe, column_name=sort_selection, ascending=ascending)

    # Display copy: scores as percentages with two decimals.
    dataframe_display = dataframe.copy()
    dataframe_display[_SCORE_COLUMNS] = (
        dataframe[_SCORE_COLUMNS].astype(float) * 100
    ).round(2)

    # Infer basic colDefs from dataframe types
    gb = GridOptionsBuilder.from_dataframe(dataframe_display)
    gb.configure_selection(selection_mode="single", use_checkbox=True)
    gb.configure_grid_options(domLayout='normal')
    grid_options = gb.build()

    # The narrow middle column only acts as a spacer between grid and chart.
    column1, _spacer, column2 = st.columns([0.26, 0.05, 0.69], gap="small")

    with column1:
        grid_response = AgGrid(
            dataframe_display,
            gridOptions=grid_options,
            height=300,
            width='40%'
        )

    # Ticked row wins; otherwise fall back to the first row of the table.
    selected_rows = grid_response['selected_rows']
    if selected_rows is not None and len(selected_rows) > 0:
        model_name = selected_rows[0]["model_name"]
    elif len(dataframe) > 0:
        model_name = dataframe["model_name"].values[0]
    else:
        model_name = ""

    with column2:
        # Hand the plotter exactly one row so the chart is well-formed even if
        # several rows share the same model name.
        model_row = dataframe[dataframe["model_name"] == model_name].head(1)
        if not model_row.empty:
            figure = plot_radar_chart_name(dataframe=model_row, model_name=model_name)
            st.plotly_chart(figure, use_container_width=True)

    st.markdown("**Model name:** %s" % model_name)
79
+
80
+
src/load_data.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+
3
def load_dataframe() -> pd.DataFrame:
    """
    Load dataframe from the csv file in public directory

    The path is relative to the current working directory.

    Returns
        dataframe: a pd.DataFrame of the average scores of the LLMs on each task
    """
    dataframe = pd.read_csv("public/datasets/models_scores.csv")
    # "Unnamed: 0" is the name pandas gives to a header-less first column
    # (typically a saved index). Drop it when present, but do not fail if the
    # CSV was written without it.
    dataframe = dataframe.drop(columns="Unnamed: 0", errors="ignore")
    return dataframe
13
+
14
def sort_by(dataframe: pd.DataFrame, column_name: str, ascending:bool = False) -> pd.DataFrame:
    """
    Return a copy of the dataframe ordered by one of its columns

    Arguments:
        - dataframe: the pandas dataframe to order (left unmodified)
        - column_name: name of the column whose values drive the ordering
        - ascending: True for smallest-first, False (the default) for largest-first

    Returns:
        a new, sorted dataframe
    """
    ordered = dataframe.sort_values(column_name, ascending=ascending)
    return ordered
src/plot.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.load_data import load_dataframe
2
+ import plotly.graph_objects as go
3
+ import numpy as np
4
+ import pandas as pd
5
+
6
+ # Hugging Face Colors
7
+ fillcolor = "#FFD21E"
8
+ line_color = "#FF9D00"
9
+
10
+ # opacity of the plot
11
+ opacity = 0.75
12
+
13
+ # categories to show radar chart
14
+ categories = ["ARC", "GSM8K", "TruthfulQA", "Winogrande", "HellaSwag", "MMLU"]
15
+
16
def _radar_figure(scores, model_name, categories, fillcolor, line_color):
    """Build the closed radar chart shared by both public plotting functions."""
    # Cast before scaling so values are numeric before the multiplication.
    values = np.asarray(scores, dtype=float) * 100
    values = values.round(decimals=2)

    # Repeat the first point at the end to close the area of the radar chart.
    values = np.append(values, values[0])
    categories_theta = [*categories, categories[0]]

    fig = go.Figure()
    fig.add_trace(go.Scatterpolar(
        r=values,
        theta=categories_theta,
        fill='toself',
        fillcolor=fillcolor,
        opacity=opacity,
        line=dict(color=line_color),
        name=model_name
    ))
    fig.update_layout(
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, 100.]
            )),
        showlegend=False
    )
    return fig


def plot_radar_chart_index(dataframe: pd.DataFrame, index: int, categories: list = categories, fillcolor: str = fillcolor, line_color:str = line_color):
    """
    plot the row of the dataframe whose index label is `index`

    Arguments:
        dataframe: a pandas DataFrame
        index: the index label (looked up with `.loc`) of the row we want to plot
        categories: the list of the metrics
        fillcolor: a string specifying the color to fill the area
        line_color: a string specifying the color of the lines in the graph

    Returns:
        a plotly Figure holding a single radar (Scatterpolar) trace
    """
    scores = dataframe.loc[index, categories].to_numpy()
    model_name = dataframe.loc[index, "model_name"]
    return _radar_figure(scores, model_name, categories, fillcolor, line_color)


def plot_radar_chart_name(dataframe: pd.DataFrame, model_name: str, categories: list = categories, fillcolor: str = fillcolor, line_color:str = line_color):
    """
    plot the results of the model named model_name in the dataframe

    When several rows share the same model name, the first one is plotted.

    Arguments:
        dataframe: a pandas DataFrame
        model_name: a string stating the name of the model
        categories: the list of the metrics
        fillcolor: a string specifying the color to fill the area
        line_color: a string specifying the color of the lines in the graph

    Returns:
        a plotly Figure holding a single radar (Scatterpolar) trace

    Raises:
        IndexError: if no row of the dataframe has this model name
    """
    matching = dataframe.loc[dataframe["model_name"] == model_name, categories]
    if matching.empty:
        raise IndexError("no model named %r in the dataframe" % (model_name,))
    # Use a single 1-D row: a multi-row selection would otherwise be flattened
    # into one long, malformed trace.
    scores = matching.iloc[0].to_numpy()
    return _radar_figure(scores, model_name, categories, fillcolor, line_color)