natolambert commited on
Commit
0b8c16d
1 Parent(s): ab74236

upload plot

Browse files
Files changed (3) hide show
  1. app.py +6 -1
  2. src/plt.py +53 -0
  3. src/utils.py +12 -0
app.py CHANGED
@@ -5,6 +5,7 @@ from apscheduler.schedulers.background import BackgroundScheduler
5
  from datasets import load_dataset
6
  from src.utils import load_all_data
7
  from src.md import ABOUT_TEXT, TOP_TEXT
 
8
  import numpy as np
9
 
10
  api = HfApi()
@@ -210,7 +211,11 @@ with gr.Blocks() as app:
210
  sample_display = gr.Markdown("{sampled data loads here}")
211
 
212
  button.click(fn=random_sample, inputs=[subset_selector], outputs=[sample_display])
213
-
 
 
 
 
214
 
215
  # Load data when app starts, TODO make this used somewhere...
216
  # def load_data_on_start():
 
5
  from datasets import load_dataset
6
  from src.utils import load_all_data
7
  from src.md import ABOUT_TEXT, TOP_TEXT
8
+ from src.plt import plot_avg_correlation
9
  import numpy as np
10
 
11
  api = HfApi()
 
211
  sample_display = gr.Markdown("{sampled data loads here}")
212
 
213
  button.click(fn=random_sample, inputs=[subset_selector], outputs=[sample_display])
214
+ # removed plot because not pretty enough
215
+ # with gr.TabItem("Model Correlation"):
216
+ # with gr.Row():
217
+ # plot = plot_avg_correlation(herm_data_avg, prefs_data)
218
+ # gr.Plot(plot)
219
 
220
  # Load data when app starts, TODO make this used somewhere...
221
  # def load_data_on_start():
src/plt.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import pandas as pd
3
+ from .utils import undo_hyperlink
4
+
5
+ def plot_avg_correlation(df1, df2):
6
+ """
7
+ Plots the "average" column for each unique model that appears in both dataframes.
8
+
9
+ Parameters:
10
+ - df1: pandas DataFrame containing columns "model" and "average".
11
+ - df2: pandas DataFrame containing columns "model" and "average".
12
+ """
13
+ # Identify the unique models that appear in both DataFrames
14
+ common_models = pd.Series(list(set(df1['model']) & set(df2['model'])))
15
+
16
+ # Set up the plot
17
+ plt.figure(figsize=(13, 6), constrained_layout=True)
18
+
19
+ # axes from 0 to 1 for x and y
20
+ plt.xlim(0.475, 0.8)
21
+ plt.ylim(0.475, 0.8)
22
+
23
+ # larger font (16)
24
+ plt.rcParams.update({'font.size': 12, 'axes.labelsize': 14,'axes.titlesize': 14})
25
+ # plt.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.1)
26
+ # plt.tight_layout()
27
+ # plt.margins(0,0)
28
+
29
+ for model in common_models:
30
+ # Filter data for the current model
31
+ df1_model_data = df1[df1['model'] == model]['average'].values
32
+ df2_model_data = df2[df2['model'] == model]['average'].values
33
+
34
+ # Plotting
35
+ plt.scatter(df1_model_data, df2_model_data, label=model)
36
+ m_name = undo_hyperlink(model)
37
+ if m_name == "No text found":
38
+ m_name = "Random"
39
+ # Add text above each point like
40
+ # plt.text(x[i] + 0.1, y[i] + 0.1, label, ha='left', va='bottom')
41
+ plt.text(df1_model_data - .005, df2_model_data, m_name, horizontalalignment='right', verticalalignment='center')
42
+
43
+ # add correlation line to scatter plot
44
+ # first, compute correlation
45
+ corr = df1['average'].corr(df2['average'])
46
+ # add correlation line based on corr
47
+
48
+
49
+
50
+ plt.xlabel('HERM Eval. Set Avg.', fontsize=16)
51
+ plt.ylabel('Pref. Test Sets Avg.', fontsize=16)
52
+ # plt.legend(title='Model', bbox_to_anchor=(1.05, 1), loc='upper left')
53
+ return plt
src/utils.py CHANGED
@@ -3,6 +3,7 @@ from pathlib import Path
3
  from datasets import load_dataset
4
  import numpy as np
5
  import os
 
6
 
7
  # From Open LLM Leaderboard
8
  def model_hyperlink(link, model_name):
@@ -10,6 +11,17 @@ def model_hyperlink(link, model_name):
10
  return "random"
11
  return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
12
 
 
 
 
 
 
 
 
 
 
 
 
13
  # Define a function to fetch and process data
14
  def load_all_data(data_repo, subdir:str, subsubsets=False): # use HF api to pull the git repo
15
  dir = Path(data_repo)
 
3
  from datasets import load_dataset
4
  import numpy as np
5
  import os
6
+ import re
7
 
8
  # From Open LLM Leaderboard
9
  def model_hyperlink(link, model_name):
 
11
  return "random"
12
  return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
13
 
14
+ def undo_hyperlink(html_string):
15
+ # Regex pattern to match content inside > and <
16
+ pattern = r'>[^<]+<'
17
+ match = re.search(pattern, html_string)
18
+ if match:
19
+ # Extract the matched text and remove leading '>' and trailing '<'
20
+ return match.group(0)[1:-1]
21
+ else:
22
+ return "No text found"
23
+
24
+
25
  # Define a function to fetch and process data
26
  def load_all_data(data_repo, subdir:str, subsubsets=False): # use HF api to pull the git repo
27
  dir = Path(data_repo)