Spaces:
Build error
Build error
Upload 4 files
Browse files- src/trend_utils.py +94 -88
src/trend_utils.py
CHANGED
|
@@ -17,6 +17,7 @@ START_DATE = '2023-06-01'
|
|
| 17 |
COLOUR_OPEN = 'red'
|
| 18 |
COLOUR_COMM = 'blue'
|
| 19 |
|
|
|
|
| 20 |
def get_param_size(params: str) -> float:
|
| 21 |
"""Convert parameter size from string to float.
|
| 22 |
|
|
@@ -41,6 +42,7 @@ def get_param_size(params: str) -> float:
|
|
| 41 |
|
| 42 |
return param_size
|
| 43 |
|
|
|
|
| 44 |
def date_difference(date_str1: str, date_str2: str) -> int:
|
| 45 |
"""Calculate the difference in days between two dates.
|
| 46 |
|
|
@@ -104,8 +106,8 @@ def get_models_to_display(result_df: pd.DataFrame, open_dip: float = 0, comm_dip
|
|
| 104 |
Returns:
|
| 105 |
tuple: Two lists of model names (open and commercial).
|
| 106 |
"""
|
| 107 |
-
open_model_df = result_df[result_df['open_weight']==True]
|
| 108 |
-
comm_model_df = result_df[result_df['open_weight']==False]
|
| 109 |
|
| 110 |
open_model_df = open_model_df.sort_values(by='release_date', ascending=True)
|
| 111 |
comm_model_df = comm_model_df.sort_values(by='release_date', ascending=True)
|
|
@@ -113,6 +115,7 @@ def get_models_to_display(result_df: pd.DataFrame, open_dip: float = 0, comm_dip
|
|
| 113 |
comm_models = populate_list(comm_model_df, comm_dip)
|
| 114 |
return open_models, comm_models
|
| 115 |
|
|
|
|
| 116 |
# Function to interpolate between two colors
|
| 117 |
def interpolate_color(rank_val, start_color):
|
| 118 |
"""
|
|
@@ -122,9 +125,10 @@ def interpolate_color(rank_val, start_color):
|
|
| 122 |
elif start_color == 'blue':
|
| 123 |
hue = 240
|
| 124 |
else:
|
| 125 |
-
raise KeyError(
|
| 126 |
-
|
| 127 |
-
|
|
|
|
| 128 |
value = 70 if rank_val == 1 else 100
|
| 129 |
|
| 130 |
return f"hsv({hue},{saturation},{value})"
|
|
@@ -141,7 +145,8 @@ def get_trend_data(text_data: dict, model_registry_data: list) -> pd.DataFrame:
|
|
| 141 |
pd.DataFrame: DataFrame containing processed model data.
|
| 142 |
"""
|
| 143 |
visited = set() # Track models that have been processed
|
| 144 |
-
result_df = pd.DataFrame(
|
|
|
|
| 145 |
|
| 146 |
text_dfs = text_data['dataframes']
|
| 147 |
for i in range(len(text_dfs)):
|
|
@@ -153,7 +158,7 @@ def get_trend_data(text_data: dict, model_registry_data: list) -> pd.DataFrame:
|
|
| 153 |
visited.add(model_name)
|
| 154 |
for dict_obj in model_registry_data:
|
| 155 |
if dict_obj["model_name"] == model_name:
|
| 156 |
-
if dict_obj["parameters"] == ""
|
| 157 |
params = "1000B"
|
| 158 |
est_flag = True
|
| 159 |
else:
|
|
@@ -161,17 +166,19 @@ def get_trend_data(text_data: dict, model_registry_data: list) -> pd.DataFrame:
|
|
| 161 |
est_flag = False
|
| 162 |
|
| 163 |
param_size = get_param_size(params)
|
| 164 |
-
|
| 165 |
-
new_data = {'model': model_name, 'clemscore': df['Clemscore'].iloc[i],
|
| 166 |
-
'
|
|
|
|
|
|
|
| 167 |
result_df.loc[len(result_df)] = new_data
|
| 168 |
break
|
| 169 |
-
|
| 170 |
return result_df # Return the compiled DataFrame
|
| 171 |
|
| 172 |
|
| 173 |
def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '2024-12-30',
|
| 174 |
-
benchmark_ticks: dict = {}, benchmark_update
|
| 175 |
"""Generate a scatter plot for the given DataFrame.
|
| 176 |
|
| 177 |
Args:
|
|
@@ -180,7 +187,7 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
|
|
| 180 |
end_date (str, optional): End date for filtering. Defaults to '2024-12-30'.
|
| 181 |
benchmark_ticks (dict, optional): Custom benchmark ticks for the version dates. Defaults to {}.
|
| 182 |
benchmark_update (dict, optional): Custom benchmark metadata containing last_updated date for the versions. Defaults to {}.
|
| 183 |
-
|
| 184 |
Keyword Args:
|
| 185 |
open_dip (float, optional): Threshold for open models' clemscore differences. Max dip in clemscore allowed to be considered in trend.
|
| 186 |
comm_dip (float, optional): Threshold for commercial models' clemscore differences. Max dip in clemscore allowed to be considered in trend.
|
|
@@ -203,7 +210,7 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
|
|
| 203 |
|
| 204 |
# Filter out data before April 2023/START_DATE
|
| 205 |
df = df[df['Release Date (Model and & Benchmark Version)'] >= pd.to_datetime(start_date)]
|
| 206 |
-
open_model_list, comm_model_list = get_models_to_display(df, open_dip, comm_dip)
|
| 207 |
models_to_display = open_model_list + comm_model_list
|
| 208 |
|
| 209 |
# Create a column to indicate if the model should be labeled
|
|
@@ -216,11 +223,11 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
|
|
| 216 |
versions = df['version'].unique()
|
| 217 |
version_names = sorted(
|
| 218 |
[ver for ver in versions],
|
| 219 |
-
key=lambda v: list(map(int, v[1:].split('_')[0].split('.'))),
|
| 220 |
reverse=True
|
| 221 |
-
)
|
| 222 |
|
| 223 |
-
version_names = version_names[:3]
|
| 224 |
df = df[df['version'].isin(tuple(version_names))]
|
| 225 |
|
| 226 |
rank = 2
|
|
@@ -228,7 +235,7 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
|
|
| 228 |
rank_value = {version_names[0]: 1}
|
| 229 |
for ver in version_names:
|
| 230 |
if ver not in rank_value:
|
| 231 |
-
rank_value[ver] = 1 - (rank-1-(max_rank/15))/(max_rank-1)
|
| 232 |
rank += 1
|
| 233 |
|
| 234 |
df['color_value'] = df.apply(
|
|
@@ -246,36 +253,40 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
|
|
| 246 |
for i in range(len(df)):
|
| 247 |
if df.iloc[i]['Model Type & Benchmark Version'] not in color_map:
|
| 248 |
if df.iloc[i]['open_weight']:
|
| 249 |
-
color_map[df.iloc[i]['Model Type & Benchmark Version']] = interpolate_color(df.iloc[i]['color_value'],
|
|
|
|
| 250 |
else:
|
| 251 |
-
color_map[df.iloc[i]['Model Type & Benchmark Version']] = interpolate_color(df.iloc[i]['color_value'],
|
|
|
|
| 252 |
|
| 253 |
-
|
| 254 |
-
|
| 255 |
|
| 256 |
# Create the scatter plot
|
| 257 |
fig = px.scatter(df,
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
|
|
|
|
|
|
| 274 |
|
| 275 |
fig.update_traces(
|
| 276 |
hovertemplate='Model Name: %{customdata[0]}<br>Release Date (Model): %{customdata[1]|%Y-%m-%d}<br>Clemscore: %{customdata[2]}<br>Benchmark Version: %{customdata[3]}<br>'
|
| 277 |
)
|
| 278 |
-
|
| 279 |
# Sort dataframes for line plotting
|
| 280 |
df_open = df[df['model'].isin(open_model_list)].sort_values(by='Release Date (Model and & Benchmark Version)')
|
| 281 |
df_commercial = df[df['model'].isin(comm_model_list)].sort_values(by='Release Date (Model and & Benchmark Version)')
|
|
@@ -291,80 +302,73 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
|
|
| 291 |
|
| 292 |
## Benchmark Version ticks
|
| 293 |
benchmark_tickvals = list(pd.to_datetime(list(benchmark_ticks.keys())))
|
| 294 |
-
custom_ticks = {k:v for k,v in custom_ticks.items() if k not in benchmark_tickvals}
|
| 295 |
custom_tickvals = list(custom_ticks.keys())
|
| 296 |
|
| 297 |
if mobile_view:
|
| 298 |
# Remove custom_tickvals within -1 month to +1 month of benchmark_tickvals for better visibility
|
| 299 |
one_month = pd.DateOffset(months=1)
|
| 300 |
filtered_custom_tickvals = [
|
| 301 |
-
date for date in custom_tickvals
|
| 302 |
-
if not any((benchmark_date - one_month <= date <= benchmark_date + one_month) for benchmark_date in
|
|
|
|
| 303 |
]
|
| 304 |
-
|
| 305 |
benchmark_tick_texts = []
|
| 306 |
for i in range(len(benchmark_tickvals)):
|
| 307 |
-
# Alternate <br> for benchmark ticks based on date difference (Eg. v1.6, v1.6.5 too close to each other for MM benchmark)
|
| 308 |
-
# if i == 0:
|
| 309 |
-
# benchmark_tick_texts.append(f"<br><br><b>{benchmark_ticks[benchmark_tickvals[i]]}</b>")
|
| 310 |
-
# else:
|
| 311 |
-
# date_diff = (benchmark_tickvals[i] - benchmark_tickvals[i - 1]).days
|
| 312 |
-
# if date_diff <= 75:
|
| 313 |
-
# benchmark_tick_texts.append(f"<br><br><br><b>{benchmark_ticks[benchmark_tickvals[i]]}</b>")
|
| 314 |
-
# else:
|
| 315 |
-
# benchmark_tick_texts.append(f"<br><br><b>{benchmark_ticks[benchmark_tickvals[i]]}</b>")
|
| 316 |
benchmark_tick_texts.append(f"<br><br><b>{benchmark_ticks[benchmark_tickvals[i]]}</b>")
|
| 317 |
-
|
| 318 |
fig.update_xaxes(
|
| 319 |
tickvals=filtered_custom_tickvals + benchmark_tickvals, # Use filtered_custom_tickvals
|
| 320 |
-
ticktext=[f"{date.strftime('%b')}<br>{date.strftime('%y')}" for date in filtered_custom_tickvals] +
|
| 321 |
-
|
| 322 |
tickangle=0,
|
| 323 |
tickfont=dict(size=10)
|
| 324 |
)
|
| 325 |
-
fig.update_yaxes(range=[0,
|
|
|
|
| 326 |
display_mode = 'lines+markers'
|
| 327 |
else:
|
| 328 |
fig.update_xaxes(
|
| 329 |
tickvals=custom_tickvals + benchmark_tickvals, # Use filtered_custom_tickvals
|
| 330 |
-
ticktext=[f"{date.strftime('%b')} {date.strftime('%Y')}" for date in custom_tickvals] +
|
| 331 |
-
|
|
|
|
| 332 |
tickangle=0,
|
| 333 |
-
tickfont=dict(size=10)
|
| 334 |
)
|
| 335 |
-
fig.update_yaxes(range=[0, max_clemscore+10])
|
| 336 |
display_mode = 'lines+markers+text'
|
| 337 |
|
| 338 |
-
|
| 339 |
# Add lines connecting the points for open models
|
| 340 |
fig.add_trace(go.Scatter(x=df_open['Release Date (Model and & Benchmark Version)'], y=df_open['clemscore'],
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
|
| 347 |
# Add lines connecting the points for commercial models
|
| 348 |
-
fig.add_trace(
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
# Update layout to ensure text labels are visible
|
| 357 |
fig.update_traces(textposition='top center')
|
| 358 |
|
| 359 |
# Update the Legend Position and plot dimensions
|
| 360 |
fig.update_layout(height=height,
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
|
| 369 |
if mobile_view:
|
| 370 |
if mobile_view:
|
|
@@ -379,11 +383,10 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
|
|
| 379 |
bgcolor='rgba(255,255,255,0.7)', # semi-transparent white for mobile
|
| 380 |
bordercolor='rgba(0,0,0,0.05)')
|
| 381 |
)
|
| 382 |
-
|
| 383 |
-
|
| 384 |
|
| 385 |
return fig
|
| 386 |
|
|
|
|
| 387 |
def get_final_trend_plot(benchmark: str = "Text", mobile_view: bool = False) -> go.Figure:
|
| 388 |
"""Fetch and generate the final trend plot for all models.
|
| 389 |
|
|
@@ -423,14 +426,15 @@ def get_final_trend_plot(benchmark: str = "Text", mobile_view: bool = False) ->
|
|
| 423 |
text_result_df = get_trend_data(text_data, model_registry_data)
|
| 424 |
## Get benchmark tickvalues as dates for X-axis
|
| 425 |
for ver in versions:
|
| 426 |
-
if 'multimodal' not in ver['version']:
|
| 427 |
benchmark_ticks[pd.to_datetime(ver['release_date'])] = ver['version']
|
| 428 |
if pd.to_datetime(ver['last_updated']) not in benchmark_update:
|
| 429 |
benchmark_update[pd.to_datetime(ver['last_updated'])] = [ver['version']]
|
| 430 |
else:
|
| 431 |
benchmark_update[pd.to_datetime(ver['last_updated'])].append(ver['version'])
|
| 432 |
|
| 433 |
-
fig =
|
|
|
|
| 434 |
else:
|
| 435 |
mm_data = get_github_data()['multimodal']
|
| 436 |
result_df = get_trend_data(mm_data, model_registry_data)
|
|
@@ -439,9 +443,11 @@ def get_final_trend_plot(benchmark: str = "Text", mobile_view: bool = False) ->
|
|
| 439 |
if 'multimodal' in ver['version']:
|
| 440 |
temp_ver = ver['version']
|
| 441 |
temp_ver = temp_ver.replace('_multimodal', '')
|
| 442 |
-
benchmark_ticks[
|
|
|
|
| 443 |
benchmark_update[pd.to_datetime(ver['last_updated'])] = temp_ver
|
| 444 |
|
| 445 |
-
fig = get_plot(df, start_date=START_DATE, end_date=datetime.now().strftime('%Y-%m-%d'),
|
|
|
|
| 446 |
|
| 447 |
return fig
|
|
|
|
| 17 |
COLOUR_OPEN = 'red'
|
| 18 |
COLOUR_COMM = 'blue'
|
| 19 |
|
| 20 |
+
|
| 21 |
def get_param_size(params: str) -> float:
|
| 22 |
"""Convert parameter size from string to float.
|
| 23 |
|
|
|
|
| 42 |
|
| 43 |
return param_size
|
| 44 |
|
| 45 |
+
|
| 46 |
def date_difference(date_str1: str, date_str2: str) -> int:
|
| 47 |
"""Calculate the difference in days between two dates.
|
| 48 |
|
|
|
|
| 106 |
Returns:
|
| 107 |
tuple: Two lists of model names (open and commercial).
|
| 108 |
"""
|
| 109 |
+
open_model_df = result_df[result_df['open_weight'] == True]
|
| 110 |
+
comm_model_df = result_df[result_df['open_weight'] == False]
|
| 111 |
|
| 112 |
open_model_df = open_model_df.sort_values(by='release_date', ascending=True)
|
| 113 |
comm_model_df = comm_model_df.sort_values(by='release_date', ascending=True)
|
|
|
|
| 115 |
comm_models = populate_list(comm_model_df, comm_dip)
|
| 116 |
return open_models, comm_models
|
| 117 |
|
| 118 |
+
|
| 119 |
# Function to interpolate between two colors
|
| 120 |
def interpolate_color(rank_val, start_color):
|
| 121 |
"""
|
|
|
|
| 125 |
elif start_color == 'blue':
|
| 126 |
hue = 240
|
| 127 |
else:
|
| 128 |
+
raise KeyError(
|
| 129 |
+
f"Invalid color selected for trend graph: {start_color}. Please set either red or blue. Alternatively, set hue value in src.trend_utils.interpolate_colour")
|
| 130 |
+
|
| 131 |
+
saturation = rank_val * 100
|
| 132 |
value = 70 if rank_val == 1 else 100
|
| 133 |
|
| 134 |
return f"hsv({hue},{saturation},{value})"
|
|
|
|
| 145 |
pd.DataFrame: DataFrame containing processed model data.
|
| 146 |
"""
|
| 147 |
visited = set() # Track models that have been processed
|
| 148 |
+
result_df = pd.DataFrame(
|
| 149 |
+
columns=['model', 'clemscore', 'open_weight', 'release_date', 'parameters', 'est_flag', 'version'])
|
| 150 |
|
| 151 |
text_dfs = text_data['dataframes']
|
| 152 |
for i in range(len(text_dfs)):
|
|
|
|
| 158 |
visited.add(model_name)
|
| 159 |
for dict_obj in model_registry_data:
|
| 160 |
if dict_obj["model_name"] == model_name:
|
| 161 |
+
if dict_obj["parameters"] == "":
|
| 162 |
params = "1000B"
|
| 163 |
est_flag = True
|
| 164 |
else:
|
|
|
|
| 166 |
est_flag = False
|
| 167 |
|
| 168 |
param_size = get_param_size(params)
|
| 169 |
+
|
| 170 |
+
new_data = {'model': model_name, 'clemscore': df['Clemscore'].iloc[i],
|
| 171 |
+
'open_weight': dict_obj['open_weight'],
|
| 172 |
+
'release_date': dict_obj['release_date'], 'parameters': param_size,
|
| 173 |
+
'est_flag': est_flag, 'version': version}
|
| 174 |
result_df.loc[len(result_df)] = new_data
|
| 175 |
break
|
| 176 |
+
|
| 177 |
return result_df # Return the compiled DataFrame
|
| 178 |
|
| 179 |
|
| 180 |
def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '2024-12-30',
|
| 181 |
+
benchmark_ticks: dict = {}, benchmark_update={}, **plot_kwargs) -> go.Figure:
|
| 182 |
"""Generate a scatter plot for the given DataFrame.
|
| 183 |
|
| 184 |
Args:
|
|
|
|
| 187 |
end_date (str, optional): End date for filtering. Defaults to '2024-12-30'.
|
| 188 |
benchmark_ticks (dict, optional): Custom benchmark ticks for the version dates. Defaults to {}.
|
| 189 |
benchmark_update (dict, optional): Custom benchmark metadata containing last_updated date for the versions. Defaults to {}.
|
| 190 |
+
|
| 191 |
Keyword Args:
|
| 192 |
open_dip (float, optional): Threshold for open models' clemscore differences. Max dip in clemscore allowed to be considered in trend.
|
| 193 |
comm_dip (float, optional): Threshold for commercial models' clemscore differences. Max dip in clemscore allowed to be considered in trend.
|
|
|
|
| 210 |
|
| 211 |
# Filter out data before April 2023/START_DATE
|
| 212 |
df = df[df['Release Date (Model and & Benchmark Version)'] >= pd.to_datetime(start_date)]
|
| 213 |
+
open_model_list, comm_model_list = get_models_to_display(df, open_dip, comm_dip)
|
| 214 |
models_to_display = open_model_list + comm_model_list
|
| 215 |
|
| 216 |
# Create a column to indicate if the model should be labeled
|
|
|
|
| 223 |
versions = df['version'].unique()
|
| 224 |
version_names = sorted(
|
| 225 |
[ver for ver in versions],
|
| 226 |
+
key=lambda v: list(map(int, v[1:].split('_')[0].split('.'))),
|
| 227 |
reverse=True
|
| 228 |
+
)
|
| 229 |
|
| 230 |
+
version_names = version_names[:3] # Select 3 latest benchmark versions
|
| 231 |
df = df[df['version'].isin(tuple(version_names))]
|
| 232 |
|
| 233 |
rank = 2
|
|
|
|
| 235 |
rank_value = {version_names[0]: 1}
|
| 236 |
for ver in version_names:
|
| 237 |
if ver not in rank_value:
|
| 238 |
+
rank_value[ver] = 1 - (rank - 1 - (max_rank / 15)) / (max_rank - 1)
|
| 239 |
rank += 1
|
| 240 |
|
| 241 |
df['color_value'] = df.apply(
|
|
|
|
| 253 |
for i in range(len(df)):
|
| 254 |
if df.iloc[i]['Model Type & Benchmark Version'] not in color_map:
|
| 255 |
if df.iloc[i]['open_weight']:
|
| 256 |
+
color_map[df.iloc[i]['Model Type & Benchmark Version']] = interpolate_color(df.iloc[i]['color_value'],
|
| 257 |
+
COLOUR_OPEN)
|
| 258 |
else:
|
| 259 |
+
color_map[df.iloc[i]['Model Type & Benchmark Version']] = interpolate_color(df.iloc[i]['color_value'],
|
| 260 |
+
COLOUR_COMM)
|
| 261 |
|
| 262 |
+
marker_size = df['parameters'].apply(lambda x: np.sqrt(x) if x > 0 else np.sqrt(400)).astype(
|
| 263 |
+
float) # Arbitrary sqrt value to scale marker size based on parameter size
|
| 264 |
|
| 265 |
# Create the scatter plot
|
| 266 |
fig = px.scatter(df,
|
| 267 |
+
x="Release Date (Model and & Benchmark Version)",
|
| 268 |
+
y="clemscore",
|
| 269 |
+
color="Model Type & Benchmark Version", # Differentiates the datasets by color
|
| 270 |
+
color_discrete_map=color_map, # Map colors to the defined subclasses
|
| 271 |
+
hover_name="model",
|
| 272 |
+
size=marker_size,
|
| 273 |
+
size_max=40, # Max size of the circles
|
| 274 |
+
template="plotly_white",
|
| 275 |
+
hover_data={ # Customize hover information
|
| 276 |
+
"Release Date (Model and & Benchmark Version)": True,
|
| 277 |
+
# Show the Release Date (Model and & Benchmark Version)
|
| 278 |
+
"clemscore": True, # Show the clemscore
|
| 279 |
+
"version": True
|
| 280 |
+
},
|
| 281 |
+
custom_data=["model", "Release Date (Model and & Benchmark Version)", "clemscore", "version"],
|
| 282 |
+
# Specify custom data columns for hover
|
| 283 |
+
opacity=0.8
|
| 284 |
+
)
|
| 285 |
|
| 286 |
fig.update_traces(
|
| 287 |
hovertemplate='Model Name: %{customdata[0]}<br>Release Date (Model): %{customdata[1]|%Y-%m-%d}<br>Clemscore: %{customdata[2]}<br>Benchmark Version: %{customdata[3]}<br>'
|
| 288 |
)
|
| 289 |
+
|
| 290 |
# Sort dataframes for line plotting
|
| 291 |
df_open = df[df['model'].isin(open_model_list)].sort_values(by='Release Date (Model and & Benchmark Version)')
|
| 292 |
df_commercial = df[df['model'].isin(comm_model_list)].sort_values(by='Release Date (Model and & Benchmark Version)')
|
|
|
|
| 302 |
|
| 303 |
## Benchmark Version ticks
|
| 304 |
benchmark_tickvals = list(pd.to_datetime(list(benchmark_ticks.keys())))
|
| 305 |
+
custom_ticks = {k: v for k, v in custom_ticks.items() if k not in benchmark_tickvals}
|
| 306 |
custom_tickvals = list(custom_ticks.keys())
|
| 307 |
|
| 308 |
if mobile_view:
|
| 309 |
# Remove custom_tickvals within -1 month to +1 month of benchmark_tickvals for better visibility
|
| 310 |
one_month = pd.DateOffset(months=1)
|
| 311 |
filtered_custom_tickvals = [
|
| 312 |
+
date for date in custom_tickvals
|
| 313 |
+
if not any((benchmark_date - one_month <= date <= benchmark_date + one_month) for benchmark_date in
|
| 314 |
+
benchmark_tickvals)
|
| 315 |
]
|
| 316 |
+
|
| 317 |
benchmark_tick_texts = []
|
| 318 |
for i in range(len(benchmark_tickvals)):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 319 |
benchmark_tick_texts.append(f"<br><br><b>{benchmark_ticks[benchmark_tickvals[i]]}</b>")
|
| 320 |
+
|
| 321 |
fig.update_xaxes(
|
| 322 |
tickvals=filtered_custom_tickvals + benchmark_tickvals, # Use filtered_custom_tickvals
|
| 323 |
+
ticktext=[f"{date.strftime('%b')}<br>{date.strftime('%y')}" for date in filtered_custom_tickvals] +
|
| 324 |
+
benchmark_tick_texts, # Use the new benchmark tick texts
|
| 325 |
tickangle=0,
|
| 326 |
tickfont=dict(size=10)
|
| 327 |
)
|
| 328 |
+
fig.update_yaxes(range=[0,
|
| 329 |
+
110]) # Set y-axis range to 110 for better visibility of legend and avoiding overlap with interactivity block of plotly on top-right
|
| 330 |
display_mode = 'lines+markers'
|
| 331 |
else:
|
| 332 |
fig.update_xaxes(
|
| 333 |
tickvals=custom_tickvals + benchmark_tickvals, # Use filtered_custom_tickvals
|
| 334 |
+
ticktext=[f"{date.strftime('%b')} {date.strftime('%Y')}" for date in custom_tickvals] +
|
| 335 |
+
[f"<br><span style='font-size:12px;'><b>{benchmark_ticks[date]}</b></span>" for date in
|
| 336 |
+
benchmark_tickvals], # Added <br> for vertical alignment
|
| 337 |
tickangle=0,
|
| 338 |
+
tickfont=dict(size=10)
|
| 339 |
)
|
| 340 |
+
fig.update_yaxes(range=[0, max_clemscore + 10])
|
| 341 |
display_mode = 'lines+markers+text'
|
| 342 |
|
|
|
|
| 343 |
# Add lines connecting the points for open models
|
| 344 |
fig.add_trace(go.Scatter(x=df_open['Release Date (Model and & Benchmark Version)'], y=df_open['clemscore'],
|
| 345 |
+
mode=display_mode, # Include 'text' in the mode
|
| 346 |
+
name='Open Models Trendline',
|
| 347 |
+
text=df_open['label_model'], # Use label_model for text labels
|
| 348 |
+
textposition='top center', # Position of the text labels
|
| 349 |
+
line=dict(color='red'), showlegend=False))
|
| 350 |
|
| 351 |
# Add lines connecting the points for commercial models
|
| 352 |
+
fig.add_trace(
|
| 353 |
+
go.Scatter(x=df_commercial['Release Date (Model and & Benchmark Version)'], y=df_commercial['clemscore'],
|
| 354 |
+
mode=display_mode, # Include 'text' in the mode
|
| 355 |
+
name='Commercial Models Trendline',
|
| 356 |
+
text=df_commercial['label_model'], # Use label_model for text labels
|
| 357 |
+
textposition='top center', # Position of the text labels
|
| 358 |
+
line=dict(color='blue'), showlegend=False))
|
| 359 |
+
|
| 360 |
+
# Update layout to ensure text labels are visible
|
| 361 |
fig.update_traces(textposition='top center')
|
| 362 |
|
| 363 |
# Update the Legend Position and plot dimensions
|
| 364 |
fig.update_layout(height=height,
|
| 365 |
+
legend=dict(
|
| 366 |
+
yanchor="top",
|
| 367 |
+
y=0.99,
|
| 368 |
+
xanchor="left",
|
| 369 |
+
x=0.01
|
| 370 |
+
)
|
| 371 |
+
)
|
| 372 |
|
| 373 |
if mobile_view:
|
| 374 |
if mobile_view:
|
|
|
|
| 383 |
bgcolor='rgba(255,255,255,0.7)', # semi-transparent white for mobile
|
| 384 |
bordercolor='rgba(0,0,0,0.05)')
|
| 385 |
)
|
|
|
|
|
|
|
| 386 |
|
| 387 |
return fig
|
| 388 |
|
| 389 |
+
|
| 390 |
def get_final_trend_plot(benchmark: str = "Text", mobile_view: bool = False) -> go.Figure:
|
| 391 |
"""Fetch and generate the final trend plot for all models.
|
| 392 |
|
|
|
|
| 426 |
text_result_df = get_trend_data(text_data, model_registry_data)
|
| 427 |
## Get benchmark tickvalues as dates for X-axis
|
| 428 |
for ver in versions:
|
| 429 |
+
if 'multimodal' not in ver['version']: # Skip MM specific benchmark dates
|
| 430 |
benchmark_ticks[pd.to_datetime(ver['release_date'])] = ver['version']
|
| 431 |
if pd.to_datetime(ver['last_updated']) not in benchmark_update:
|
| 432 |
benchmark_update[pd.to_datetime(ver['last_updated'])] = [ver['version']]
|
| 433 |
else:
|
| 434 |
benchmark_update[pd.to_datetime(ver['last_updated'])].append(ver['version'])
|
| 435 |
|
| 436 |
+
fig = get_plot(text_result_df, start_date=START_DATE, end_date=datetime.now().strftime('%Y-%m-%d'),
|
| 437 |
+
benchmark_ticks=benchmark_ticks, benchmark_update=benchmark_update, **plot_kwargs)
|
| 438 |
else:
|
| 439 |
mm_data = get_github_data()['multimodal']
|
| 440 |
result_df = get_trend_data(mm_data, model_registry_data)
|
|
|
|
| 443 |
if 'multimodal' in ver['version']:
|
| 444 |
temp_ver = ver['version']
|
| 445 |
temp_ver = temp_ver.replace('_multimodal', '')
|
| 446 |
+
benchmark_ticks[
|
| 447 |
+
pd.to_datetime(ver['release_date'])] = temp_ver ## MM benchmark dates considered after v1.6 (incl.)
|
| 448 |
benchmark_update[pd.to_datetime(ver['last_updated'])] = temp_ver
|
| 449 |
|
| 450 |
+
fig = get_plot(df, start_date=START_DATE, end_date=datetime.now().strftime('%Y-%m-%d'),
|
| 451 |
+
benchmark_ticks=benchmark_ticks, benchmark_update=benchmark_update, **plot_kwargs)
|
| 452 |
|
| 453 |
return fig
|