dimbyTa's picture
adding feature of changing the order of the metrics on the circle of the chart
ff0c7de
raw
history blame
No virus
2.54 kB
import streamlit as st
import pandas as pd
@st.cache_data
def load_dataframe() -> pd.DataFrame:
    """
    Read the table of per-task LLM scores used by the app.

    The file public/datasets/models_scores.csv is opened relative to the
    current working directory. After loading, its "Unnamed: 0" column is
    removed. The result is cached by Streamlit (st.cache_data).
    Returns:
    a pd.DataFrame of the average scores of the LLMs on each task
    """
    # pandas names a header-less column "Unnamed: N"; this is presumably an
    # index written by an earlier to_csv call -- confirm against the dataset.
    # DataFrame.drop raises KeyError if the column is absent.
    scores = pd.read_csv("public/datasets/models_scores.csv")
    return scores.drop(columns="Unnamed: 0")
@st.cache_data
def show_dataframe_top(n: int, dataframe: pd.DataFrame) -> pd.DataFrame:
    """
    Return the first n rows of a dataframe.

    Arguments:
    - n: how many leading rows to keep (passed straight to DataFrame.head)
    - dataframe: the dataframe to slice
    Returns:
    a pd.DataFrame holding at most the first n rows of dataframe
    """
    leading_rows = dataframe.head(n)
    return leading_rows
@st.cache_data
def sort_by(dataframe: pd.DataFrame, column_name: str, ascending: bool = False) -> pd.DataFrame:
    """
    Reorder the rows of a dataframe by the values of one column.

    Arguments:
    - dataframe: the pandas dataframe to reorder
    - column_name: name of the column whose values drive the ordering
    - ascending: True for smallest-first, False (the default) for largest-first
    Returns:
    a new dataframe with the rows reordered; the input dataframe is left as-is
    """
    # sort_values returns a new object unless inplace=True is passed.
    ordered = dataframe.sort_values(column_name, ascending=ascending)
    return ordered
@st.cache_data
def search_by_name(name: str) -> pd.DataFrame:
    """
    Search a model by its name.

    Arguments:
    - name: the name of the model or part of it. It is matched literally as a
      case-sensitive substring, so characters such as ".", "(" or "[" have no
      special meaning.
    Returns:
    a pandas Dataframe of every row whose "model_name" contains name; rows
    with a missing model_name are never returned
    """
    dataframe = load_dataframe()
    # regex=False: the query is free text, not a pattern. With the default
    # (regex=True), "." would match any character and an unbalanced "(" or "["
    # would raise re.error.
    # na=False: a missing model_name would otherwise put NaN into the mask,
    # and boolean indexing with NaN raises ValueError.
    mask = dataframe["model_name"].str.contains(name, regex=False, na=False)
    return dataframe[mask]
def validate_categories(categories: list) -> bool:
    """
    Check that categories is a valid ordering of the six benchmark columns.

    Arguments:
    - categories: a list of category names giving the order of the columns in
      the dataframe. It is accepted only when it has exactly six elements and
      each of the following names appears in it (in any order):
        - ARC
        - GSM8K
        - TruthfulQA
        - Winogrande
        - HellaSwag
        - MMLU
    Returns
    - True if the list has six elements and contains every expected name
    - False otherwise
    """
    required = ("ARC", "GSM8K", "TruthfulQA", "Winogrande", "HellaSwag", "MMLU")
    if len(categories) != 6:
        return False
    # Six elements that include all six distinct names must be exactly a
    # permutation of them. Membership is tested on the caller's container
    # (not a set), so unhashable entries fail the check instead of raising.
    return all(name in categories for name in required)