Koshti10 committed on
Commit
0004b1e
·
verified ·
1 Parent(s): 93eafdd

Upload 4 files

Browse files
Files changed (1) hide show
  1. src/trend_utils.py +94 -88
src/trend_utils.py CHANGED
@@ -17,6 +17,7 @@ START_DATE = '2023-06-01'
17
  COLOUR_OPEN = 'red'
18
  COLOUR_COMM = 'blue'
19
 
 
20
  def get_param_size(params: str) -> float:
21
  """Convert parameter size from string to float.
22
 
@@ -41,6 +42,7 @@ def get_param_size(params: str) -> float:
41
 
42
  return param_size
43
 
 
44
  def date_difference(date_str1: str, date_str2: str) -> int:
45
  """Calculate the difference in days between two dates.
46
 
@@ -104,8 +106,8 @@ def get_models_to_display(result_df: pd.DataFrame, open_dip: float = 0, comm_dip
104
  Returns:
105
  tuple: Two lists of model names (open and commercial).
106
  """
107
- open_model_df = result_df[result_df['open_weight']==True]
108
- comm_model_df = result_df[result_df['open_weight']==False]
109
 
110
  open_model_df = open_model_df.sort_values(by='release_date', ascending=True)
111
  comm_model_df = comm_model_df.sort_values(by='release_date', ascending=True)
@@ -113,6 +115,7 @@ def get_models_to_display(result_df: pd.DataFrame, open_dip: float = 0, comm_dip
113
  comm_models = populate_list(comm_model_df, comm_dip)
114
  return open_models, comm_models
115
 
 
116
  # Function to interpolate between two colors
117
  def interpolate_color(rank_val, start_color):
118
  """
@@ -122,9 +125,10 @@ def interpolate_color(rank_val, start_color):
122
  elif start_color == 'blue':
123
  hue = 240
124
  else:
125
- raise KeyError(f"Invalid color selected for trend graph: {start_color}. Please set either red or blue. Alternatively, set hue value in src.trend_utils.interpolate_colour")
126
-
127
- saturation = rank_val*100
 
128
  value = 70 if rank_val == 1 else 100
129
 
130
  return f"hsv({hue},{saturation},{value})"
@@ -141,7 +145,8 @@ def get_trend_data(text_data: dict, model_registry_data: list) -> pd.DataFrame:
141
  pd.DataFrame: DataFrame containing processed model data.
142
  """
143
  visited = set() # Track models that have been processed
144
- result_df = pd.DataFrame(columns=['model', 'clemscore', 'open_weight', 'release_date', 'parameters', 'est_flag', 'version'])
 
145
 
146
  text_dfs = text_data['dataframes']
147
  for i in range(len(text_dfs)):
@@ -153,7 +158,7 @@ def get_trend_data(text_data: dict, model_registry_data: list) -> pd.DataFrame:
153
  visited.add(model_name)
154
  for dict_obj in model_registry_data:
155
  if dict_obj["model_name"] == model_name:
156
- if dict_obj["parameters"] == "" :
157
  params = "1000B"
158
  est_flag = True
159
  else:
@@ -161,17 +166,19 @@ def get_trend_data(text_data: dict, model_registry_data: list) -> pd.DataFrame:
161
  est_flag = False
162
 
163
  param_size = get_param_size(params)
164
-
165
- new_data = {'model': model_name, 'clemscore': df['Clemscore'].iloc[i], 'open_weight':dict_obj['open_weight'],
166
- 'release_date': dict_obj['release_date'], 'parameters': param_size, 'est_flag': est_flag, 'version': version}
 
 
167
  result_df.loc[len(result_df)] = new_data
168
  break
169
-
170
  return result_df # Return the compiled DataFrame
171
 
172
 
173
  def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '2024-12-30',
174
- benchmark_ticks: dict = {}, benchmark_update = {}, **plot_kwargs) -> go.Figure:
175
  """Generate a scatter plot for the given DataFrame.
176
 
177
  Args:
@@ -180,7 +187,7 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
180
  end_date (str, optional): End date for filtering. Defaults to '2024-12-30'.
181
  benchmark_ticks (dict, optional): Custom benchmark ticks for the version dates. Defaults to {}.
182
  benchmark_update (dict, optional): Custom benchmark metadata containing last_updated date for the versions. Defaults to {}.
183
-
184
  Keyword Args:
185
  open_dip (float, optional): Threshold for open models' clemscore differences. Max dip in clemscore allowed to be considered in trend.
186
  comm_dip (float, optional): Threshold for commercial models' clemscore differences. Max dip in clemscore allowed to be considered in trend.
@@ -203,7 +210,7 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
203
 
204
  # Keep only rows whose release date is on or after start_date
205
  df = df[df['Release Date (Model and & Benchmark Version)'] >= pd.to_datetime(start_date)]
206
- open_model_list, comm_model_list = get_models_to_display(df, open_dip, comm_dip)
207
  models_to_display = open_model_list + comm_model_list
208
 
209
  # Create a column to indicate if the model should be labeled
@@ -216,11 +223,11 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
216
  versions = df['version'].unique()
217
  version_names = sorted(
218
  [ver for ver in versions],
219
- key=lambda v: list(map(int, v[1:].split('_')[0].split('.'))),
220
  reverse=True
221
- )
222
 
223
- version_names = version_names[:3] # Select 3 latest benchmark versions
224
  df = df[df['version'].isin(tuple(version_names))]
225
 
226
  rank = 2
@@ -228,7 +235,7 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
228
  rank_value = {version_names[0]: 1}
229
  for ver in version_names:
230
  if ver not in rank_value:
231
- rank_value[ver] = 1 - (rank-1-(max_rank/15))/(max_rank-1)
232
  rank += 1
233
 
234
  df['color_value'] = df.apply(
@@ -246,36 +253,40 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
246
  for i in range(len(df)):
247
  if df.iloc[i]['Model Type & Benchmark Version'] not in color_map:
248
  if df.iloc[i]['open_weight']:
249
- color_map[df.iloc[i]['Model Type & Benchmark Version']] = interpolate_color(df.iloc[i]['color_value'], COLOUR_OPEN)
 
250
  else:
251
- color_map[df.iloc[i]['Model Type & Benchmark Version']] = interpolate_color(df.iloc[i]['color_value'], COLOUR_COMM)
 
252
 
253
-
254
- marker_size = df['parameters'].apply(lambda x: np.sqrt(x) if x > 0 else np.sqrt(400)).astype(float) # Arbitrary sqrt value to scale marker size based on parameter size
255
 
256
  # Create the scatter plot
257
  fig = px.scatter(df,
258
- x="Release Date (Model and & Benchmark Version)",
259
- y="clemscore",
260
- color="Model Type & Benchmark Version", # Differentiates the datasets by color
261
- color_discrete_map=color_map, # Map colors to the defined subclasses
262
- hover_name="model",
263
- size=marker_size,
264
- size_max=40, # Max size of the circles
265
- template="plotly_white",
266
- hover_data={ # Customize hover information
267
- "Release Date (Model and & Benchmark Version)": True, # Show the Release Date (Model and & Benchmark Version)
268
- "clemscore": True, # Show the clemscore
269
- "version": True
270
- },
271
- custom_data=["model", "Release Date (Model and & Benchmark Version)", "clemscore", "version"], # Specify custom data columns for hover
272
- opacity=0.8
273
- )
 
 
274
 
275
  fig.update_traces(
276
  hovertemplate='Model Name: %{customdata[0]}<br>Release Date (Model): %{customdata[1]|%Y-%m-%d}<br>Clemscore: %{customdata[2]}<br>Benchmark Version: %{customdata[3]}<br>'
277
  )
278
-
279
  # Sort dataframes for line plotting
280
  df_open = df[df['model'].isin(open_model_list)].sort_values(by='Release Date (Model and & Benchmark Version)')
281
  df_commercial = df[df['model'].isin(comm_model_list)].sort_values(by='Release Date (Model and & Benchmark Version)')
@@ -291,80 +302,73 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
291
 
292
  ## Benchmark Version ticks
293
  benchmark_tickvals = list(pd.to_datetime(list(benchmark_ticks.keys())))
294
- custom_ticks = {k:v for k,v in custom_ticks.items() if k not in benchmark_tickvals}
295
  custom_tickvals = list(custom_ticks.keys())
296
 
297
  if mobile_view:
298
  # Remove custom_tickvals within -1 month to +1 month of benchmark_tickvals for better visibility
299
  one_month = pd.DateOffset(months=1)
300
  filtered_custom_tickvals = [
301
- date for date in custom_tickvals
302
- if not any((benchmark_date - one_month <= date <= benchmark_date + one_month) for benchmark_date in benchmark_tickvals)
 
303
  ]
304
-
305
  benchmark_tick_texts = []
306
  for i in range(len(benchmark_tickvals)):
307
- # Alternate <br> for benchmark ticks based on date difference (Eg. v1.6, v1.6.5 too close to each other for MM benchmark)
308
- # if i == 0:
309
- # benchmark_tick_texts.append(f"<br><br><b>{benchmark_ticks[benchmark_tickvals[i]]}</b>")
310
- # else:
311
- # date_diff = (benchmark_tickvals[i] - benchmark_tickvals[i - 1]).days
312
- # if date_diff <= 75:
313
- # benchmark_tick_texts.append(f"<br><br><br><b>{benchmark_ticks[benchmark_tickvals[i]]}</b>")
314
- # else:
315
- # benchmark_tick_texts.append(f"<br><br><b>{benchmark_ticks[benchmark_tickvals[i]]}</b>")
316
  benchmark_tick_texts.append(f"<br><br><b>{benchmark_ticks[benchmark_tickvals[i]]}</b>")
317
-
318
  fig.update_xaxes(
319
  tickvals=filtered_custom_tickvals + benchmark_tickvals, # Use filtered_custom_tickvals
320
- ticktext=[f"{date.strftime('%b')}<br>{date.strftime('%y')}" for date in filtered_custom_tickvals] +
321
- benchmark_tick_texts, # Use the new benchmark tick texts
322
  tickangle=0,
323
  tickfont=dict(size=10)
324
  )
325
- fig.update_yaxes(range=[0, 110]) # Set y-axis range to 110 for better visibility of legend and avoiding overlap with interactivity block of plotly on top-right
 
326
  display_mode = 'lines+markers'
327
  else:
328
  fig.update_xaxes(
329
  tickvals=custom_tickvals + benchmark_tickvals, # Use filtered_custom_tickvals
330
- ticktext=[f"{date.strftime('%b')} {date.strftime('%Y')}" for date in custom_tickvals] +
331
- [f"<br><span style='font-size:12px;'><b>{benchmark_ticks[date]}</b></span>" for date in benchmark_tickvals], # Added <br> for vertical alignment
 
332
  tickangle=0,
333
- tickfont=dict(size=10)
334
  )
335
- fig.update_yaxes(range=[0, max_clemscore+10])
336
  display_mode = 'lines+markers+text'
337
 
338
-
339
  # Add lines connecting the points for open models
340
  fig.add_trace(go.Scatter(x=df_open['Release Date (Model and & Benchmark Version)'], y=df_open['clemscore'],
341
- mode=display_mode, # Include 'text' in the mode
342
- name='Open Models Trendline',
343
- text=df_open['label_model'], # Use label_model for text labels
344
- textposition='top center', # Position of the text labels
345
- line=dict(color='red'), showlegend=False))
346
 
347
  # Add lines connecting the points for commercial models
348
- fig.add_trace(go.Scatter(x=df_commercial['Release Date (Model and & Benchmark Version)'], y=df_commercial['clemscore'],
349
- mode=display_mode, # Include 'text' in the mode
350
- name='Commercial Models Trendline',
351
- text=df_commercial['label_model'], # Use label_model for text labels
352
- textposition='top center', # Position of the text labels
353
- line=dict(color='blue'), showlegend=False))
354
-
355
-
356
- # Update layout to ensure text labels are visible
357
  fig.update_traces(textposition='top center')
358
 
359
  # Update the Legend Position and plot dimensions
360
  fig.update_layout(height=height,
361
- legend=dict(
362
- yanchor="top",
363
- y=0.99,
364
- xanchor="left",
365
- x=0.01
366
- )
367
- )
368
 
369
  if mobile_view:
370
  if mobile_view:
@@ -379,11 +383,10 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
379
  bgcolor='rgba(255,255,255,0.7)', # semi-transparent white for mobile
380
  bordercolor='rgba(0,0,0,0.05)')
381
  )
382
-
383
-
384
 
385
  return fig
386
 
 
387
  def get_final_trend_plot(benchmark: str = "Text", mobile_view: bool = False) -> go.Figure:
388
  """Fetch and generate the final trend plot for all models.
389
 
@@ -423,14 +426,15 @@ def get_final_trend_plot(benchmark: str = "Text", mobile_view: bool = False) ->
423
  text_result_df = get_trend_data(text_data, model_registry_data)
424
  ## Get benchmark tickvalues as dates for X-axis
425
  for ver in versions:
426
- if 'multimodal' not in ver['version']: # Skip MM specific benchmark dates
427
  benchmark_ticks[pd.to_datetime(ver['release_date'])] = ver['version']
428
  if pd.to_datetime(ver['last_updated']) not in benchmark_update:
429
  benchmark_update[pd.to_datetime(ver['last_updated'])] = [ver['version']]
430
  else:
431
  benchmark_update[pd.to_datetime(ver['last_updated'])].append(ver['version'])
432
 
433
- fig = get_plot(text_result_df, start_date=START_DATE, end_date=datetime.now().strftime('%Y-%m-%d'), benchmark_ticks=benchmark_ticks, benchmark_update=benchmark_update, **plot_kwargs)
 
434
  else:
435
  mm_data = get_github_data()['multimodal']
436
  result_df = get_trend_data(mm_data, model_registry_data)
@@ -439,9 +443,11 @@ def get_final_trend_plot(benchmark: str = "Text", mobile_view: bool = False) ->
439
  if 'multimodal' in ver['version']:
440
  temp_ver = ver['version']
441
  temp_ver = temp_ver.replace('_multimodal', '')
442
- benchmark_ticks[pd.to_datetime(ver['release_date'])] = temp_ver ## MM benchmark dates considered after v1.6 (incl.)
 
443
  benchmark_update[pd.to_datetime(ver['last_updated'])] = temp_ver
444
 
445
- fig = get_plot(df, start_date=START_DATE, end_date=datetime.now().strftime('%Y-%m-%d'), benchmark_ticks=benchmark_ticks, benchmark_update=benchmark_update, **plot_kwargs)
 
446
 
447
  return fig
 
17
  COLOUR_OPEN = 'red'
18
  COLOUR_COMM = 'blue'
19
 
20
+
21
  def get_param_size(params: str) -> float:
22
  """Convert parameter size from string to float.
23
 
 
42
 
43
  return param_size
44
 
45
+
46
  def date_difference(date_str1: str, date_str2: str) -> int:
47
  """Calculate the difference in days between two dates.
48
 
 
106
  Returns:
107
  tuple: Two lists of model names (open and commercial).
108
  """
109
+ open_model_df = result_df[result_df['open_weight'] == True]
110
+ comm_model_df = result_df[result_df['open_weight'] == False]
111
 
112
  open_model_df = open_model_df.sort_values(by='release_date', ascending=True)
113
  comm_model_df = comm_model_df.sort_values(by='release_date', ascending=True)
 
115
  comm_models = populate_list(comm_model_df, comm_dip)
116
  return open_models, comm_models
117
 
118
+
119
  # Function to interpolate between two colors
120
  def interpolate_color(rank_val, start_color):
121
  """
 
125
  elif start_color == 'blue':
126
  hue = 240
127
  else:
128
+ raise KeyError(
129
+ f"Invalid color selected for trend graph: {start_color}. Please set either red or blue. Alternatively, set hue value in src.trend_utils.interpolate_colour")
130
+
131
+ saturation = rank_val * 100
132
  value = 70 if rank_val == 1 else 100
133
 
134
  return f"hsv({hue},{saturation},{value})"
 
145
  pd.DataFrame: DataFrame containing processed model data.
146
  """
147
  visited = set() # Track models that have been processed
148
+ result_df = pd.DataFrame(
149
+ columns=['model', 'clemscore', 'open_weight', 'release_date', 'parameters', 'est_flag', 'version'])
150
 
151
  text_dfs = text_data['dataframes']
152
  for i in range(len(text_dfs)):
 
158
  visited.add(model_name)
159
  for dict_obj in model_registry_data:
160
  if dict_obj["model_name"] == model_name:
161
+ if dict_obj["parameters"] == "":
162
  params = "1000B"
163
  est_flag = True
164
  else:
 
166
  est_flag = False
167
 
168
  param_size = get_param_size(params)
169
+
170
+ new_data = {'model': model_name, 'clemscore': df['Clemscore'].iloc[i],
171
+ 'open_weight': dict_obj['open_weight'],
172
+ 'release_date': dict_obj['release_date'], 'parameters': param_size,
173
+ 'est_flag': est_flag, 'version': version}
174
  result_df.loc[len(result_df)] = new_data
175
  break
176
+
177
  return result_df # Return the compiled DataFrame
178
 
179
 
180
  def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '2024-12-30',
181
+ benchmark_ticks: dict = {}, benchmark_update={}, **plot_kwargs) -> go.Figure:
182
  """Generate a scatter plot for the given DataFrame.
183
 
184
  Args:
 
187
  end_date (str, optional): End date for filtering. Defaults to '2024-12-30'.
188
  benchmark_ticks (dict, optional): Custom benchmark ticks for the version dates. Defaults to {}.
189
  benchmark_update (dict, optional): Custom benchmark metadata containing last_updated date for the versions. Defaults to {}.
190
+
191
  Keyword Args:
192
  open_dip (float, optional): Threshold for open models' clemscore differences. Max dip in clemscore allowed to be considered in trend.
193
  comm_dip (float, optional): Threshold for commercial models' clemscore differences. Max dip in clemscore allowed to be considered in trend.
 
210
 
211
  # Keep only rows whose release date is on or after start_date
212
  df = df[df['Release Date (Model and & Benchmark Version)'] >= pd.to_datetime(start_date)]
213
+ open_model_list, comm_model_list = get_models_to_display(df, open_dip, comm_dip)
214
  models_to_display = open_model_list + comm_model_list
215
 
216
  # Create a column to indicate if the model should be labeled
 
223
  versions = df['version'].unique()
224
  version_names = sorted(
225
  [ver for ver in versions],
226
+ key=lambda v: list(map(int, v[1:].split('_')[0].split('.'))),
227
  reverse=True
228
+ )
229
 
230
+ version_names = version_names[:3] # Select 3 latest benchmark versions
231
  df = df[df['version'].isin(tuple(version_names))]
232
 
233
  rank = 2
 
235
  rank_value = {version_names[0]: 1}
236
  for ver in version_names:
237
  if ver not in rank_value:
238
+ rank_value[ver] = 1 - (rank - 1 - (max_rank / 15)) / (max_rank - 1)
239
  rank += 1
240
 
241
  df['color_value'] = df.apply(
 
253
  for i in range(len(df)):
254
  if df.iloc[i]['Model Type & Benchmark Version'] not in color_map:
255
  if df.iloc[i]['open_weight']:
256
+ color_map[df.iloc[i]['Model Type & Benchmark Version']] = interpolate_color(df.iloc[i]['color_value'],
257
+ COLOUR_OPEN)
258
  else:
259
+ color_map[df.iloc[i]['Model Type & Benchmark Version']] = interpolate_color(df.iloc[i]['color_value'],
260
+ COLOUR_COMM)
261
 
262
+ marker_size = df['parameters'].apply(lambda x: np.sqrt(x) if x > 0 else np.sqrt(400)).astype(
263
+ float) # Arbitrary sqrt value to scale marker size based on parameter size
264
 
265
  # Create the scatter plot
266
  fig = px.scatter(df,
267
+ x="Release Date (Model and & Benchmark Version)",
268
+ y="clemscore",
269
+ color="Model Type & Benchmark Version", # Differentiates the datasets by color
270
+ color_discrete_map=color_map, # Map colors to the defined subclasses
271
+ hover_name="model",
272
+ size=marker_size,
273
+ size_max=40, # Max size of the circles
274
+ template="plotly_white",
275
+ hover_data={ # Customize hover information
276
+ "Release Date (Model and & Benchmark Version)": True,
277
+ # Show the Release Date (Model and & Benchmark Version)
278
+ "clemscore": True, # Show the clemscore
279
+ "version": True
280
+ },
281
+ custom_data=["model", "Release Date (Model and & Benchmark Version)", "clemscore", "version"],
282
+ # Specify custom data columns for hover
283
+ opacity=0.8
284
+ )
285
 
286
  fig.update_traces(
287
  hovertemplate='Model Name: %{customdata[0]}<br>Release Date (Model): %{customdata[1]|%Y-%m-%d}<br>Clemscore: %{customdata[2]}<br>Benchmark Version: %{customdata[3]}<br>'
288
  )
289
+
290
  # Sort dataframes for line plotting
291
  df_open = df[df['model'].isin(open_model_list)].sort_values(by='Release Date (Model and & Benchmark Version)')
292
  df_commercial = df[df['model'].isin(comm_model_list)].sort_values(by='Release Date (Model and & Benchmark Version)')
 
302
 
303
  ## Benchmark Version ticks
304
  benchmark_tickvals = list(pd.to_datetime(list(benchmark_ticks.keys())))
305
+ custom_ticks = {k: v for k, v in custom_ticks.items() if k not in benchmark_tickvals}
306
  custom_tickvals = list(custom_ticks.keys())
307
 
308
  if mobile_view:
309
  # Remove custom_tickvals within -1 month to +1 month of benchmark_tickvals for better visibility
310
  one_month = pd.DateOffset(months=1)
311
  filtered_custom_tickvals = [
312
+ date for date in custom_tickvals
313
+ if not any((benchmark_date - one_month <= date <= benchmark_date + one_month) for benchmark_date in
314
+ benchmark_tickvals)
315
  ]
316
+
317
  benchmark_tick_texts = []
318
  for i in range(len(benchmark_tickvals)):
 
 
 
 
 
 
 
 
 
319
  benchmark_tick_texts.append(f"<br><br><b>{benchmark_ticks[benchmark_tickvals[i]]}</b>")
320
+
321
  fig.update_xaxes(
322
  tickvals=filtered_custom_tickvals + benchmark_tickvals, # Use filtered_custom_tickvals
323
+ ticktext=[f"{date.strftime('%b')}<br>{date.strftime('%y')}" for date in filtered_custom_tickvals] +
324
+ benchmark_tick_texts, # Use the new benchmark tick texts
325
  tickangle=0,
326
  tickfont=dict(size=10)
327
  )
328
+ fig.update_yaxes(range=[0,
329
+ 110]) # Set y-axis range to 110 for better visibility of legend and avoiding overlap with interactivity block of plotly on top-right
330
  display_mode = 'lines+markers'
331
  else:
332
  fig.update_xaxes(
333
  tickvals=custom_tickvals + benchmark_tickvals, # Use filtered_custom_tickvals
334
+ ticktext=[f"{date.strftime('%b')} {date.strftime('%Y')}" for date in custom_tickvals] +
335
+ [f"<br><span style='font-size:12px;'><b>{benchmark_ticks[date]}</b></span>" for date in
336
+ benchmark_tickvals], # Added <br> for vertical alignment
337
  tickangle=0,
338
+ tickfont=dict(size=10)
339
  )
340
+ fig.update_yaxes(range=[0, max_clemscore + 10])
341
  display_mode = 'lines+markers+text'
342
 
 
343
  # Add lines connecting the points for open models
344
  fig.add_trace(go.Scatter(x=df_open['Release Date (Model and & Benchmark Version)'], y=df_open['clemscore'],
345
+ mode=display_mode, # Include 'text' in the mode
346
+ name='Open Models Trendline',
347
+ text=df_open['label_model'], # Use label_model for text labels
348
+ textposition='top center', # Position of the text labels
349
+ line=dict(color='red'), showlegend=False))
350
 
351
  # Add lines connecting the points for commercial models
352
+ fig.add_trace(
353
+ go.Scatter(x=df_commercial['Release Date (Model and & Benchmark Version)'], y=df_commercial['clemscore'],
354
+ mode=display_mode, # Include 'text' in the mode
355
+ name='Commercial Models Trendline',
356
+ text=df_commercial['label_model'], # Use label_model for text labels
357
+ textposition='top center', # Position of the text labels
358
+ line=dict(color='blue'), showlegend=False))
359
+
360
+ # Update layout to ensure text labels are visible
361
  fig.update_traces(textposition='top center')
362
 
363
  # Update the Legend Position and plot dimensions
364
  fig.update_layout(height=height,
365
+ legend=dict(
366
+ yanchor="top",
367
+ y=0.99,
368
+ xanchor="left",
369
+ x=0.01
370
+ )
371
+ )
372
 
373
  if mobile_view:
374
  if mobile_view:
 
383
  bgcolor='rgba(255,255,255,0.7)', # semi-transparent white for mobile
384
  bordercolor='rgba(0,0,0,0.05)')
385
  )
 
 
386
 
387
  return fig
388
 
389
+
390
  def get_final_trend_plot(benchmark: str = "Text", mobile_view: bool = False) -> go.Figure:
391
  """Fetch and generate the final trend plot for all models.
392
 
 
426
  text_result_df = get_trend_data(text_data, model_registry_data)
427
  ## Get benchmark tickvalues as dates for X-axis
428
  for ver in versions:
429
+ if 'multimodal' not in ver['version']: # Skip MM specific benchmark dates
430
  benchmark_ticks[pd.to_datetime(ver['release_date'])] = ver['version']
431
  if pd.to_datetime(ver['last_updated']) not in benchmark_update:
432
  benchmark_update[pd.to_datetime(ver['last_updated'])] = [ver['version']]
433
  else:
434
  benchmark_update[pd.to_datetime(ver['last_updated'])].append(ver['version'])
435
 
436
+ fig = get_plot(text_result_df, start_date=START_DATE, end_date=datetime.now().strftime('%Y-%m-%d'),
437
+ benchmark_ticks=benchmark_ticks, benchmark_update=benchmark_update, **plot_kwargs)
438
  else:
439
  mm_data = get_github_data()['multimodal']
440
  result_df = get_trend_data(mm_data, model_registry_data)
 
443
  if 'multimodal' in ver['version']:
444
  temp_ver = ver['version']
445
  temp_ver = temp_ver.replace('_multimodal', '')
446
+ benchmark_ticks[
447
+ pd.to_datetime(ver['release_date'])] = temp_ver ## MM benchmark dates considered after v1.6 (incl.)
448
  benchmark_update[pd.to_datetime(ver['last_updated'])] = temp_ver
449
 
450
+ fig = get_plot(df, start_date=START_DATE, end_date=datetime.now().strftime('%Y-%m-%d'),
451
+ benchmark_ticks=benchmark_ticks, benchmark_update=benchmark_update, **plot_kwargs)
452
 
453
  return fig