import numpy as np
import pandas as pd
import plotly.graph_objects as go

from src.load_data import load_dataframe

# Hugging Face Colors
fillcolor = "#FFD21E"
line_color = "#FF9D00"

# opacity of the plot
opacity = 0.75

# categories to show radar chart
categories = ["ARC", "GSM8K", "TruthfulQA", "Winogrande", "HellaSwag", "MMLU"]


def _build_radar_figure(scores, model_name, categories, fillcolor, line_color) -> go.Figure:
    """
    build the radar chart shared by the public plotting functions

    Arguments:
        scores: a 1-D array-like with one raw score per entry of categories
        model_name: the name given to the trace
        categories: the list of the metrics, used as the angular labels
        fillcolor: a string specifying the color to fill the area
        line_color: a string specifying the color of the lines in the graph

    Returns:
        a plotly Figure holding a single closed Scatterpolar trace
    """
    # scale the scores by 100 (the radial axis below is fixed to [0, 100])
    # and round them to two decimals
    values = (np.asarray(scores) * 100).astype(float).round(decimals=2)

    # repeat the first point at the end to close the area of the radar chart
    values = np.append(values, values[0])
    theta = list(categories) + [categories[0]]

    fig = go.Figure()
    fig.add_trace(
        go.Scatterpolar(
            r=values,
            theta=theta,
            fill="toself",
            fillcolor=fillcolor,
            opacity=opacity,  # module-level setting
            line=dict(color=line_color),
            name=model_name,
        )
    )
    fig.update_layout(
        polar=dict(radialaxis=dict(visible=True, range=[0, 100.0])),
        showlegend=False,
    )
    return fig


def plot_radar_chart_index(dataframe: pd.DataFrame, index: int, categories: list = categories, fillcolor: str = fillcolor, line_color: str = line_color) -> go.Figure:
    """
    plot the index-th row of the dataframe

    Arguments:
        dataframe: a pandas DataFrame containing the columns listed in
            categories and a "model_name" column
        index: the index of the row we want to plot; it is looked up with
            dataframe.loc, i.e. as an index label
        categories: the list of the metrics
        fillcolor: a string specifying the color to fill the area
        line_color: a string specifying the color of the lines in the graph

    Returns:
        a plotly Figure with the radar chart of the selected row
    """
    scores = dataframe.loc[index, categories].to_numpy()
    model_name = dataframe.loc[index, "model_name"]
    return _build_radar_figure(scores, model_name, categories, fillcolor, line_color)


def plot_radar_chart_name(dataframe: pd.DataFrame, model_name: str, categories: list = categories, fillcolor: str = fillcolor, line_color: str = line_color) -> go.Figure:
    """
    plot the results of the model named model_name

    Arguments:
        dataframe: a pandas DataFrame containing the columns listed in
            categories and a "model_name" column
        model_name: a string stating the name of the model
        categories: the list of the metrics
        fillcolor: a string specifying the color to fill the area
        line_color: a string specifying the color of the lines in the graph

    Returns:
        a plotly Figure with the radar chart of the first row whose
        "model_name" equals model_name

    Raises:
        IndexError: if no row of the dataframe has the requested model_name
    """
    # the boolean mask yields a 2-D array with one row per matching model
    matches = dataframe[dataframe["model_name"] == model_name][categories].to_numpy()
    if matches.shape[0] == 0:
        raise IndexError(f"no row with model_name == {model_name!r} in the dataframe")

    # keep a single 1-D row: np.append without an axis flattens its inputs,
    # so closing the polygon on the 2-D array would produce more radial
    # values than angular labels
    return _build_radar_figure(matches[0], model_name, categories, fillcolor, line_color)