chriscanal commited on
Commit
36bf409
1 Parent(s): 8e47868

Updated plotted models to exclude flagged models

Browse files
Files changed (1) hide show
  1. src/display_models/plot_results.py +5 -1
src/display_models/plot_results.py CHANGED
@@ -4,6 +4,7 @@ from plotly.graph_objs import Figure
4
  import pickle
5
  from datetime import datetime, timezone
6
  from typing import List, Dict, Tuple, Any
 
7
 
8
  # Average ⬆️ human baseline is 0.897 (source: averaging human baselines below)
9
  # ARC human baseline is 0.80 (source: https://lab42.global/arc/)
@@ -42,6 +43,9 @@ def join_model_info_with_results(results_df: pd.DataFrame) -> pd.DataFrame:
42
  # copy dataframe to avoid modifying the original
43
  df = results_df.copy(deep=True)
44
 
 
 
 
45
  # load cache from disk
46
  try:
47
  with open("model_info_cache.pkl", "rb") as f:
@@ -216,4 +220,4 @@ def create_metric_plot_obj(
216
 
217
  # Example Usage:
218
  # human_baselines dictionary is defined.
219
- # chart = create_metric_plot_obj(scores_df, ["ARC", "HellaSwag", "MMLU", "TruthfulQA"], human_baselines, "Graph Title")
 
4
  import pickle
5
  from datetime import datetime, timezone
6
  from typing import List, Dict, Tuple, Any
7
+ from src.display_models.model_metadata_flags import FLAGGED_MODELS
8
 
9
  # Average ⬆️ human baseline is 0.897 (source: averaging human baselines below)
10
  # ARC human baseline is 0.80 (source: https://lab42.global/arc/)
 
43
  # copy dataframe to avoid modifying the original
44
  df = results_df.copy(deep=True)
45
 
46
+ # Filter out FLAGGED_MODELS to ensure graph is not skewed by mistakes
47
+ df = df[~df["model_name_for_query"].isin(FLAGGED_MODELS.keys())].reset_index(drop=True)
48
+
49
  # load cache from disk
50
  try:
51
  with open("model_info_cache.pkl", "rb") as f:
 
220
 
221
  # Example Usage:
222
  # human_baselines dictionary is defined.
223
+ # chart = create_metric_plot_obj(scores_df, ["ARC", "HellaSwag", "MMLU", "TruthfulQA"], human_baselines, "Graph Title")