import streamlit as st
import plotly.graph_objects as go
import numpy as np
import pandas as pd

# Hugging Face Colors
fillcolor = "#FFD21E"
line_color = "#FF9D00"

fill_color_list = [fillcolor, "#F05998", "#40BAF0"]
line_color_list = [line_color, "#5E233C", "#194A5E"]

# opacity of the plot
opacity = 0.75

# In multi-model charts each successive trace is drawn this much more
# transparent, but never below the floor: Plotly rejects opacity outside [0, 1].
_OPACITY_STEP = 0.2
_MIN_OPACITY = 0.1

# categories to show radar chart
categories = ["ARC", "GSM8K", "TruthfulQA", "Winogrande", "HellaSwag", "MMLU"]

# Dataset columns
columns = ["model_name", "ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K", "MMLU", "Average"]


def _close_polygon(values, labels):
    """Return ``(r, theta)`` with the first point repeated at the end.

    Repeating the first value/label closes the outline of the radar area.

    Arguments:
        values: 1-D numpy array of radial values, one per label
        labels: sequence of category labels
    """
    labels = list(labels)
    return np.append(values, values[:1]), labels + labels[:1]


def _apply_polar_layout(fig, show_legend):
    """Fix the radial axis to the 0-100 range and set legend visibility."""
    fig.update_layout(
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, 100.0],
            )
        ),
        showlegend=show_legend,
    )


def _single_trace_figure(scores, model_name, categories, fillcolor, line_color):
    """Build a radar figure with one filled trace for a single model.

    Arguments:
        scores: 1-D numpy array of values, one per entry of ``categories``
        model_name: name attached to the trace
        categories: the list of the metrics (angular labels)
        fillcolor: colour used to fill the area
        line_color: colour of the trace outline
    """
    r, theta = _close_polygon(scores, categories)
    fig = go.Figure()
    fig.add_trace(go.Scatterpolar(
        r=r,
        theta=theta,
        fill="toself",
        fillcolor=fillcolor,
        opacity=opacity,
        line=dict(color=line_color),
        name=model_name,
    ))
    _apply_polar_layout(fig, show_legend=False)
    return fig


#@st.cache_data
def plot_radar_chart_index(dataframe: pd.DataFrame, index: int, categories: list = categories,
                           fillcolor: str = fillcolor, line_color: str = line_color):
    """Plot one row of the dataframe, selected by its index label.

    Arguments:
        dataframe: a pandas DataFrame with a "model_name" column and one column per category
        index: the row label to plot (looked up with ``DataFrame.loc``, i.e. by label, not position)
        categories: the list of the metrics
        fillcolor: a string specifying the color to fill the area
        line_color: a string specifying the color of the lines in the graph

    Returns:
        a plotly ``go.Figure`` with a single radar trace
    """
    # Convert to float *before* scaling so string-typed cells are parsed
    # rather than string-repeated by ``* 100``.
    # NOTE(review): the x100 scaling implies scores are stored as fractions
    # in [0, 1] -- confirm against the data source.
    scores = dataframe.loc[index, list(categories)].to_numpy(dtype=float) * 100
    scores = scores.round(decimals=2)
    model_name = dataframe.loc[index, "model_name"]
    return _single_trace_figure(scores, model_name, categories, fillcolor, line_color)


#@st.cache_data
def plot_radar_chart_name(dataframe: pd.DataFrame, model_name: str, categories: list = categories,
                          fillcolor: str = fillcolor, line_color: str = line_color):
    """Plot the scores of the model whose "model_name" equals ``model_name``.

    If several rows share the name, only the first matching row is plotted.

    Arguments:
        dataframe: a pandas DataFrame with a "model_name" column and one column per category
        model_name: a string stating the name of the model
        categories: the list of the metrics
        fillcolor: a string specifying the color to fill the area
        line_color: a string specifying the color of the lines in the graph

    Returns:
        a plotly ``go.Figure`` with a single radar trace

    Raises:
        IndexError: if no row has the given ``model_name``
    """
    matches = dataframe.loc[dataframe["model_name"] == model_name, list(categories)]
    if len(matches) == 0:
        raise IndexError(f"no row with model_name == {model_name!r}")
    # Take one 1-D row; appending the closing point to a 2-D selection would
    # flatten it and produce more radial values than angular labels.
    # NOTE(review): x100 scaling implies fractions in [0, 1] -- confirm.
    scores = matches.iloc[0].to_numpy(dtype=float) * 100
    scores = scores.round(decimals=2)
    return _single_trace_figure(scores, model_name, categories, fillcolor, line_color)


#@st.cache_data
def plot_radar_chart_rows(rows: object, columns: list = columns, categories: list = categories,
                          fillcolor_list: list = fill_color_list, line_color_list: list = line_color_list):
    """Plot one radar trace per selected row (e.g. the rows ticked in a table).

    Arguments:
        rows: an iterable whose elements are dicts with columns as their keys
        columns: the list of the columns to use
        categories: the list of the metrics
        fillcolor_list: fill colours, one per trace; reused cyclically when
            there are more rows than colours
        line_color_list: outline colours, one per trace; reused cyclically

    Returns:
        a plotly ``go.Figure``; the legend is shown only when more than one
        model is plotted
    """
    dataset = pd.DataFrame(rows, columns=columns)
    # Values are plotted as-is (no x100 scaling), so rows presumably already
    # hold percentages -- NOTE(review): confirm against the caller.
    scores = dataset[list(categories)].to_numpy(dtype=float)
    theta = list(categories) + list(categories[:1])

    fig = go.Figure()
    for i, (name, row) in enumerate(zip(dataset["model_name"], scores)):
        fig.add_trace(go.Scatterpolar(
            r=np.append(row, row[:1]),  # repeat first point to close the area
            theta=theta,
            fill="toself",
            fillcolor=fillcolor_list[i % len(fillcolor_list)],
            # Later traces are more transparent so overlaps stay readable.
            opacity=max(opacity - i * _OPACITY_STEP, _MIN_OPACITY),
            line=dict(color=line_color_list[i % len(line_color_list)]),
            name=name,
        ))
    # Count rows from the DataFrame so ``rows`` may be any iterable, including a generator.
    _apply_polar_layout(fig, show_legend=len(dataset) > 1)
    return fig