File size: 4,394 Bytes
35378f6 4f18cc8 35378f6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
import pandas as pd
import plotly.express as px
from src.assets.text_content import SHORT_NAMES
def plotly_plot(df: pd.DataFrame, LIST: list, ALL: list, NAMES: list):
    '''
    Build a scatter plot (% Played vs. Quality Score) for a selection of models.

    Columns are addressed by position: column 0 holds the model names,
    column 2 is plotted on the x axis and column 3 on the y axis.

    Args:
        df: A dummy dataframe of the latest version. It is not modified.
        LIST: List of models to plot.
        ALL: Either [] or ["Show All Models"]. If truthy, every model in df
            is plotted and LIST is ignored.
        NAMES: Either [] or ["Show Names"]. If truthy, each point is labelled
            with the model's short name (see label_map).
    Returns:
        fig: plotly figure
    '''
    columns = list(df.columns)
    model_col, x_col, y_col = columns[0], columns[2], columns[3]

    all_models = list(df[model_col].unique())
    if ALL:
        LIST = all_models

    # Map short names row by row instead of assigning the per-unique-model
    # list positionally. The positional version raised whenever a model
    # appeared in more than one row. assign() returns a copy, so the caller's
    # frame no longer gains a "Short" column as a side effect.
    short_names = label_map(all_models)
    plot_df = df.assign(Short=df[model_col].map(short_names))
    plot_df = plot_df[plot_df[model_col].isin(LIST)]

    scatter_kwargs = dict(
        x=x_col,
        y=y_col,
        color=model_col,
        symbol=model_col,
        # NOTE(review): these keys look like a leftover placeholder. Colour is
        # keyed on model names, so they only take effect for a model literally
        # named "category1"/"category2". Kept for parity; confirm before removing.
        color_discrete_map={"category1": "blue", "category2": "red"},
        hover_name=model_col,
        template="plotly_white",
    )
    if NAMES:
        fig = px.scatter(plot_df, text="Short", **scatter_kwargs)
        fig.update_traces(textposition='top center')
    else:
        fig = px.scatter(plot_df, **scatter_kwargs)

    fig.update_layout(
        xaxis_title='% Played',
        yaxis_title='Quality Score',
        title='Overview of benchmark results',
        height=1000,
    )
    # Fixed ranges with a 5-unit margin around 0-100. This presumably assumes
    # both metrics are 0-100 scores; confirm against the data source.
    fig.update_xaxes(range=[-5, 105])
    fig.update_yaxes(range=[-5, 105])
    return fig
# Column layout the plotting helpers rely on (accessed by position): ['Model', 'Clemscore', 'All(Played)', 'All(Quality Score)']
def compare_plots(df: pd.DataFrame, LIST1: list, LIST2: list, ALL: list, NAMES: list):
    '''
    Plot Quality Score against % Played for the models chosen in the frontend.

    The open-source and commercial selections are concatenated (open-source
    first) and handed to plotly_plot together with the two view toggles.

    Args:
        df: A dummy dataframe of the latest version.
        LIST1: Open-source models selected in the frontend.
        LIST2: Commercial models selected in the frontend.
        ALL: Either [] or ["Show All Models"] - plot every model when truthy.
        NAMES: Either [] or ["Show Names"] - label points when truthy.
    Returns:
        fig: The plotly figure built by plotly_plot.
    '''
    return plotly_plot(df, LIST1 + LIST2, ALL, NAMES)
def shorten_model_name(full_name):
    '''
    Derive a compact plot label from a hyphen-separated model name.

    - Single-segment names are truncated to their first three characters.
    - Multi-segment names become the first character followed by every
      segment that contains a digit (sizes / versions), joined by hyphens,
      e.g. "llama-2-70b-chat" -> "l-2-70b".
    - If no segment contains a digit, the three-character prefix is used
      instead. This avoids a label with a dangling hyphen such as "c-".

    Args:
        full_name: The full model name.
    Returns:
        short_name: The shortened label.
    '''
    parts = full_name.split('-')
    # Keep only segments carrying digits, which are typically model sizes and
    # versions. Segments come from split('-'), so none of them contain a
    # hyphen, and the join below can never have leading or trailing hyphens.
    digit_parts = [part for part in parts if any(char.isdigit() for char in part)]

    if len(parts) == 1 or not digit_parts:
        return full_name[:3]
    return full_name[0] + '-' + '-'.join(digit_parts)
def label_map(model_list: list) -> dict:
    '''
    Map each long model name to the short label used in the frontend graph.

    A hand-picked label from SHORT_NAMES (src/assets/text_content.py) takes
    precedence. Any other name is shortened with shorten_model_name, which
    is only called for names missing from SHORT_NAMES.

    Args:
        model_list: A list of long model names.
    Returns:
        A dict from long name to short name.
    '''
    return {
        name: SHORT_NAMES[name] if name in SHORT_NAMES else shorten_model_name(name)
        for name in model_list
    }
def split_models(MODEL_LIST: list, *, commercial_prefixes=('gpt-', 'claude-', 'command')):
    '''
    Split the models into open source and commercial.

    A model counts as commercial when its name starts with one of
    commercial_prefixes (case-sensitive). Every other model counts as open
    source. Both groups are sorted case-insensitively. The sort is stable,
    so names differing only in case keep their input order.

    Args:
        MODEL_LIST: Model names to classify. It is not modified.
        commercial_prefixes: Keyword-only. A prefix string or a tuple/iterable
            of prefix strings that mark a commercial model. Defaults to the
            previously hard-coded ('gpt-', 'claude-', 'command').
    Returns:
        (open_models, comm_models): two new sorted lists.
    '''
    # str.startswith accepts a str or a tuple, so normalise other iterables
    # (e.g. a list) to a tuple. A single string must not be split into chars.
    if isinstance(commercial_prefixes, str):
        prefixes = (commercial_prefixes,)
    else:
        prefixes = tuple(commercial_prefixes)

    open_models = []
    comm_models = []
    for model in MODEL_LIST:
        if model.startswith(prefixes):
            comm_models.append(model)
        else:
            open_models.append(model)

    open_models.sort(key=lambda name: name.upper())
    comm_models.sort(key=lambda name: name.upper())
    return open_models, comm_models
|